diff --git a/.bazelrc b/.bazelrc index cf15d0976b1..f2aa3ac447b 100644 --- a/.bazelrc +++ b/.bazelrc @@ -143,6 +143,11 @@ build:mkl --define=tensorflow_mkldnn_contraction_kernel=0 build:mkl --define=build_with_mkl_dnn_v1_only=true build:mkl -c opt +# config to build OneDNN backend with a user specified threadpool. +build:mkl_threadpool --define=build_with_mkl=true --define=enable_mkl=true +build:mkl_threadpool --define=tensorflow_mkldnn_contraction_kernel=0 +build:mkl_threadpool --define=build_with_mkldnn_threadpool=true +build:mkl_threadpool -c opt # This config refers to building with CUDA available. It does not necessarily # mean that we build CUDA op kernels. build:using_cuda --define=using_cuda=true @@ -163,6 +168,8 @@ build:cuda_clang --action_env TF_CUDA_CLANG=1 build:dbg --config=opt -c dbg # for now, disable arm_neon. see: https://github.com/tensorflow/tensorflow/issues/33360 build:dbg --cxxopt -DTF_LITE_DISABLE_X86_NEON +# AWS SDK must be compiled in release mode. see: https://github.com/tensorflow/tensorflow/issues/37498 +build:dbg --copt -DDEBUG_BUILD build:tensorrt --action_env TF_NEED_TENSORRT=1 @@ -233,10 +240,15 @@ build:c++17 --cxxopt=-std=c++1z build:c++17 --cxxopt=-stdlib=libc++ build:c++1z --config=c++17 -# Enable using platform specific build settings +# Enable using platform specific build settings, except when cross-compiling for +# mobile platforms. build --enable_platform_specific_config +build:android --noenable_platform_specific_config +build:ios --noenable_platform_specific_config # Suppress C++ compiler warnings, otherwise build logs become 10s of MBs. +build:android --copt=-w +build:ios --copt=-w build:linux --copt=-w build:macos --copt=-w build:windows --copt=/w @@ -256,6 +268,10 @@ build:macos --define=INCLUDEDIR=$(PREFIX)/include # TF_SYSTEM_LIBS do not work on windows. # By default, build TF in C++ 14 mode. +build:android --cxxopt=-std=c++14 +build:android --host_cxxopt=-std=c++14 +build:ios --cxxopt=-std=c++14 +build:ios --host_cxxopt=-std=c++14 build:linux --cxxopt=-std=c++14 build:linux --host_cxxopt=-std=c++14 build:macos --cxxopt=-std=c++14 diff --git a/.github/bot_config.yml b/.github/bot_config.yml new file mode 100644 index 00000000000..88c737f41e2 --- /dev/null +++ b/.github/bot_config.yml @@ -0,0 +1,87 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# +# THIS IS A GENERATED DOCKERFILE. +# +# This file was assembled from multiple pieces, whose use is documented +# throughout. Please refer to the TensorFlow dockerfiles documentation +# for more information. 
+ +# A list of assignees +assignees: + - amahendrakar + - ravikyram + - Saduf2019 +# A list of assignees for compiler folder +compiler_assignees: + - joker-eph +# Cuda Comment +cuda_comment: > + From the template it looks like you are installing **TensorFlow** (TF) prebuilt binaries: + * For TF-GPU - See point 1 + * For TF-CPU - See point 2 + ----------------------------------------------------------------------------------------------- + + **1. Installing **TensorFlow-GPU** (TF) prebuilt binaries** + + + Make sure you are using compatible TF and CUDA versions. + Please refer to the following TF and CUDA version compatibility table. + + | TF | CUDA | + + | :-------------: | :-------------: | + + | 2.1.0 - 2.2.0 | 10.1 | + + | 1.13.1 - 2.0 | 10.0 | + + | 1.5.0 - 1.12.0 | 9.0 | + + * If you have the above configuration and are using the _**Windows**_ platform - + * Try adding the CUDA, CUPTI, and cuDNN installation directories to the %PATH% environment variable. + * Refer to the [windows setup guide](https://www.tensorflow.org/install/gpu#windows_setup). + * If you have the above configuration and are using the _**Ubuntu/Linux**_ platform - + * Try adding the CUDA, CUPTI, and cuDNN installation directories to the $LD_LIBRARY_PATH environment variable. + * Refer to the [linux setup guide](https://www.tensorflow.org/install/gpu#linux_setup). + * If the error still persists, your CPU model apparently does not support AVX instruction sets. + * Refer to the [hardware requirements](https://www.tensorflow.org/install/pip#hardware-requirements). + + ----------------------------------------------------------------------------------------------- + + **2. Installing **TensorFlow** (TF) CPU prebuilt binaries** + + + *TensorFlow release binaries version 1.6 and higher are prebuilt with AVX instruction sets.* + + + Therefore, on any CPU that does not have these instruction sets, either the CPU or the GPU version of TF will fail to load. + + Apparently, your CPU model does not support AVX instruction sets. You can still use TensorFlow with the alternatives given below: + + * Try Google Colab to use TensorFlow. + * The easiest way to use TF is to switch to [google colab](https://colab.sandbox.google.com/notebooks/welcome.ipynb#recent=true). It comes with the latest stable TF version pre-installed, and you can use ```pip install``` to install any other preferred TF version. + * It has the added advantage that you can easily switch between different hardware accelerators (CPU, GPU, TPU) as the task requires. + * All you need is a good internet connection and you are all set. + * Try building TF from source with different CPU optimization flags. + + *Please let us know if this helps.* + +windows_comment: > + From the stack trace it looks like you are hitting the Windows path length limit. + * Try disabling the path length limit on Windows 10. + * Refer to the [disable path length limit instructions guide](https://mspoweruser.com/ntfs-260-character-windows-10/). + + Please let us know if this helps. 
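The `cuda_comment` template above walks users through checking that their TensorFlow build, CUDA libraries, and CPU instruction sets are compatible. As a minimal sketch (not part of this patch), the following Python snippet uses only public TensorFlow APIs to surface the same information locally; it assumes a standard TensorFlow 2.x installation.

```python
# Minimal diagnostic sketch (not part of this patch): prints the details the
# cuda_comment template asks users to verify.
import tensorflow as tf

# TF version: compare against the TF/CUDA compatibility table above.
print("TensorFlow version:", tf.__version__)

# Whether this binary was built with CUDA support (TF-GPU vs. TF-CPU).
print("Built with CUDA:", tf.test.is_built_with_cuda())

# An empty list here usually means the CUDA, CUPTI, and cuDNN directories are
# not on %PATH% (Windows) or $LD_LIBRARY_PATH (Linux), as described above.
print("Visible GPUs:", tf.config.list_physical_devices("GPU"))
```

If the process aborts with an "illegal instruction" error before anything is printed, the AVX guidance in point 2 of the template is usually the relevant one rather than the CUDA setup.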
diff --git a/README.md b/README.md index 27032043e07..a76b1bfd0b7 100644 --- a/README.md +++ b/README.md @@ -103,17 +103,17 @@ open-source software development: ### Official Builds -Build Type | Status | Artifacts ------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- -**Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/) -**Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/) -**Linux XLA** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA -**macOS** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/) -**Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/) -**Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/) -**Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) -**Raspberry Pi 0 and 1** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py2.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py2.html) [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py2](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp27-none-linux_armv6l.whl) [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl) -**Raspberry Pi 2 and 3** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py2.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py2.html) [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py2](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp27-none-linux_armv7l.whl) [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl) +Build Type | Status | Artifacts +------------------------ | 
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- +**Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/) +**Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/) +**Linux XLA** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA +**macOS** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/) +**Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/) +**Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/) +**Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) +**Raspberry Pi 0 and 1** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl) +**Raspberry Pi 2 and 3** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl) ### Community Supported Builds @@ -142,6 +142,7 @@ Build Type | Status * [Getting Started with TensorFlow 2 from Coursera](https://www.coursera.org/learn/getting-started-with-tensor-flow2) * [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187) * [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190) +* [Machine Learning with TensorFlow on GCP](https://www.coursera.org/specializations/machine-learning-tensorflow-gcp) * [TensorFlow Blog](https://blog.tensorflow.org) * [Learn ML with TensorFlow](https://www.tensorflow.org/resources/learn-ml) * [TensorFlow Twitter](https://twitter.com/tensorflow) diff --git a/RELEASE.md b/RELEASE.md index b5d088821e4..6f3aa94c203 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,185 @@ +# Release 2.3.0 + +## Breaking Changes + +* `tf.image.extract_glimpse` has been updated to correctly process the case + where `centered=False` and `normalized=False`. 
This is a breaking change as + the output is different from (incorrect) previous versions. Note this + breaking change only impacts `tf.image.extract_glimpse` and + `tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of + `tf.compat.v1.image.extract_glimpse` does not change. The behavior of + the existing C++ kernel `ExtractGlimpse` does not change either, so saved + models will not be impacted. + +# Release 2.1.1 + +## Bug Fixes and Other Changes +* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645) +* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601) +* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960) +* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770) +* Fixes a versioning bug which causes Keras layers from TF 1.x to be used instead of those from TF 2.x. + +# Release 2.0.2 + +## Bug Fixes and Other Changes +* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645) +* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601) +* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960) +* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770) + +# Release 1.15.3 + +## Bug Fixes and Other Changes +* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645) +* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601) +* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960) +* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), 
[CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770) + +# Release 2.2.0 + +TensorFlow 2.2 discontinues support for Python 2, [previously announced](https://groups.google.com/a/tensorflow.org/d/msg/announce/gVwS5RC8mds/dCt1ka2XAAAJ) as following [Python 2's EOL on January 1, 2020](https://www.python.org/dev/peps/pep-0373/#update). + +Coinciding with this change, new releases of [TensorFlow's Docker images](https://hub.docker.com/r/tensorflow/tensorflow/) provide Python 3 exclusively. Because all images now use Python 3, Docker tags containing `-py3` will no longer be provided and existing `-py3` tags like `latest-py3` will not be updated. + +## Major Features and Improvements + +* Replaced the scalar type for string tensors from `std::string` to `tensorflow::tstring` which is now ABI stable. +* A new Profiler for TF 2 for CPU/GPU/TPU. It offers both device and host performance analysis, including input pipeline and TF Ops. Optimization advisory is provided whenever possible. Please see [this tutorial](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras) and [guide](https://www.tensorflow.org/guide/profiler) for usage guidelines. +* Export C++ functions to Python using `pybind11` as opposed to `SWIG` as a part of our [deprecation of swig efforts](https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md). +* `tf.distribute`: + * Support added for global sync `BatchNormalization` by using the newly added `tf.keras.layers.experimental.SyncBatchNormalization` layer. This layer will sync `BatchNormalization` statistics every step across all replicas taking part in sync training. + * Performance improvements for GPU multi-worker distributed training using `tf.distribute.experimental.MultiWorkerMirroredStrategy` + * Update NVIDIA `NCCL` to `2.5.7-1` for better performance and performance tuning. Please see [nccl developer guide](https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/env.html) for more information on this. + * Support gradient `allreduce` in `float16`. See this [example](https://github.com/tensorflow/models/blob/master/official/staging/training/grad_utils.py) usage. + * Experimental support of [all reduce gradient packing](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/CollectiveHints) to allow overlapping gradient aggregation with backward path computation. + * Deprecated `experimental_run_v2` method for distribution strategies and renamed the method `run` as it is no longer experimental. + * Add CompositeTensor support for DistributedIterators. This should help prevent unnecessary function retracing and memory leaks. +* `tf.keras`: + * `Model.fit` major improvements: + * You can now use custom training logic with `Model.fit` by overriding `Model.train_step`. + * Easily write state-of-the-art training loops without worrying about all of the features `Model.fit` handles for you (distribution strategies, callbacks, data formats, looping logic, etc) + * See the default [`Model.train_step`](https://github.com/tensorflow/tensorflow/blob/1381fc8e15e22402417b98e3881dfd409998daea/tensorflow/python/keras/engine/training.py#L540) for an example of what this function should look like. Same applies for validation and inference via `Model.test_step` and `Model.predict_step`. 
+ * SavedModel uses its own `Model._saved_model_inputs_spec` attr now instead of + relying on `Model.inputs` and `Model.input_names`, which are no longer set for subclass Models. + This attr is set in eager, `tf.function`, and graph modes. This gets rid of the need for users to + manually call `Model._set_inputs` when using Custom Training Loops (CTLs). + * Dynamic shapes are supported for generators by calling the Model on the first batch we "peek" from the generator. + This used to happen implicitly in `Model._standardize_user_data`. Long-term, a solution where the + `DataAdapter` doesn't need to call the Model is probably preferable. + * The SavedModel format now supports all Keras built-in layers (including metrics, preprocessing layers, and stateful RNN layers). + * Update the Keras batch normalization layer to use the running mean and average computation in `fused_batch_norm`. You should see significant performance improvements when using `fused_batch_norm` in Eager mode. + +* `tf.lite`: + * Enable the TFLite experimental new converter by default. +* XLA + * XLA now builds and works on Windows. All prebuilt packages come with XLA available. + * XLA can be [enabled for a `tf.function`](https://www.tensorflow.org/xla#explicit_compilation_with_tffunction +) with “compile or throw exception” semantics on CPU and GPU. + +## Breaking Changes +* `tf.keras`: + * In `tf.keras.applications` the name of the "top" layer has been standardized to "predictions". This is only a problem if your code relies on the exact name of the layer. + * The Huber loss function has been updated to be consistent with other Keras losses. It now computes the mean over the last axis of per-sample losses before applying the reduction function. +* AutoGraph no longer converts functions passed to `tf.py_function`, `tf.py_func` and `tf.numpy_function`. +* Deprecating `XLA_CPU` and `XLA_GPU` devices with this release. +* Increasing the minimum Bazel version required to build TF to 2.0.0, in order to use Bazel's `cc_experimental_shared_library`. +* Keras compile/fit behavior for functional and subclassed models has been unified. Model properties such as `metrics` and `metrics_names` will now be available only after **training/evaluating the model on actual data** for functional models. `metrics` will **now include** model `loss` and output losses. The `loss_functions` property has been removed from the model. This was an undocumented property that was accidentally public and has now been removed. + +## Known Caveats +* The current TensorFlow release now **requires** [gast](https://pypi.org/project/gast/) version 0.3.3. + +## Bug Fixes and Other Changes +* `tf.data`: + * Removed `autotune_algorithm` from experimental optimization options. +* TF Core: + * `tf.constant` always creates CPU tensors irrespective of the current device context. + * Eager `TensorHandles` maintain a list of mirrors for any copies to local or remote devices. This avoids any redundant copies due to op execution. + * For `tf.Tensor` & `tf.Variable`, `.experimental_ref()` is no longer experimental and is available as simply `.ref()`. + * `pfor/vectorized_map`: Added support for vectorizing 56 more ops. Vectorizing `tf.cond` is also supported now. + * Set as much partial shape as we can infer statically within the gradient impl of the gather op. + * Gradient of `tf.while_loop` emits a `StatelessWhile` op if `cond` and body functions are stateless. This allows multiple gradient while ops to run in parallel under distribution strategy. 
+ * Speed up `GradientTape` in eager mode by auto-generating a list of op inputs/outputs which are unused and hence not cached for gradient functions. + * Support `back_prop=False` in `while_v2` but mark it as deprecated. + * Improve error message when attempting to use `None` in data-dependent control flow. + * Add `RaggedTensor.numpy()`. + * Update `RaggedTensor.__getitem__` to preserve uniform dimensions & allow indexing into uniform dimensions. + * Update `tf.expand_dims` to always insert the new dimension as a non-ragged dimension. + * Update `tf.embedding_lookup` to use `partition_strategy` and `max_norm` when `ids` is ragged. + * Allow `batch_dims==rank(indices)` in `tf.gather`. + * Add support for bfloat16 in `tf.print`. +* `tf.distribute`: + * Support `embedding_column` with variable-length input features for `MultiWorkerMirroredStrategy`. +* `tf.keras`: + * Added `experimental_aggregate_gradients` argument to `tf.keras.optimizer.Optimizer.apply_gradients`. This allows custom gradient aggregation and processing of aggregated gradients in a custom training loop. + * Allow `pathlib.Path` paths for loading models via the Keras API. +* `tf.function`/AutoGraph: + * AutoGraph is now available in `ReplicaContext.merge_call`, `Strategy.extended.update` and `Strategy.extended.update_non_slot`. + * Experimental support for shape invariants has been enabled in `tf.function`. See the API docs for `tf.autograph.experimental.set_loop_options` for additional info. + * AutoGraph error messages now exclude frames corresponding to APIs internal to AutoGraph. + * Improve shape inference for `tf.function` input arguments to unlock more Grappler optimizations in TensorFlow 2.x. + * Improve automatic control dependency management of resources by allowing resource reads to occur in parallel and synchronizing only on writes. + * Fix execution order of multiple stateful calls to `experimental_run_v2` in `tf.function`. + * You can now iterate over `RaggedTensors` using a for loop inside `tf.function`. +* `tf.lite`: + * Migrated the `tf.lite` C inference API out of experimental into lite/c. + * Add an option to disallow `NNAPI` CPU / partial acceleration on Android 10. + * TFLite Android AARs now include the C headers and APIs that are required to use TFLite from native code. + * Refactors the delegate and delegate kernel sources to allow usage in the linter. + * Limit delegated ops to actually supported ones if a device name is specified or `NNAPI` CPU Fallback is disabled. + * TFLite now supports the `tf.math.reciprocal1` op by lowering it to the `tf.div` op. + * TFLite's unpack op now supports boolean tensor inputs. + * Microcontroller and embedded code moved from experimental to the main TensorFlow Lite folder. + * Check for large TFLite tensors. + * Fix GPU delegate crash with C++17. + * Add 5D support to TFLite `strided_slice`. + * Fix error in delegation of `DEPTH_TO_SPACE` to `NNAPI` causing the op not to be accelerated. + * Fix segmentation fault when running a model with LSTM nodes using the `NNAPI` delegate. + * Fix `NNAPI` delegate failure when an operand for the Maximum/Minimum operation is a scalar. + * Fix `NNAPI` delegate failure when the Axis input for a reduce operation is a scalar. + * Expose an option to limit the number of partitions that will be delegated to `NNAPI`. + * If a target accelerator is specified, use its feature level to determine operations to delegate instead of the SDK version. 
+* `tf.random`: + * Various random number generation improvements: + * Add a fast path for default `random_uniform`. + * `random_seed` documentation improvement. + * `RandomBinomial` broadcasts and appends the sample shape to the left rather than the right. + * Added `tf.random.stateless_binomial`, `tf.random.stateless_gamma`, `tf.random.stateless_poisson`. + * `tf.random.stateless_uniform` now supports unbounded sampling of `int` types. +* Math and Linear Algebra: + * Add `tf.linalg.LinearOperatorTridiag`. + * Add `LinearOperatorBlockLowerTriangular`. + * Add broadcasting support to `tf.linalg.triangular_solve` [#26204](https://github.com/tensorflow/tensorflow/issues/26204) and `tf.math.invert_permutation`. + * Add `tf.math.sobol_sample` op. + * Add `tf.math.xlog1py`. + * Add `tf.math.special.{dawsn,expi,fresnel_cos,fresnel_sin,spence}`. + * Add a Modified Discrete Cosine Transform (MDCT) and its inverse to `tf.signal`. +* TPU Enhancements: + * Refactor `TpuClusterResolver` to move shared logic to a separate pip package. + * Support configuring the TPU software version from the cloud TPU client. + * Allowed the TPU embedding weight decay factor to be multiplied by the learning rate. +* XLA Support: + * Add standalone XLA AOT runtime target + relevant .cc sources to the pip package. + * Add a check for memory alignment to MemoryAllocation::MemoryAllocation() on 32-bit ARM. This ensures a deterministic early exit instead of a hard-to-debug bus error later. + * `saved_model_cli aot_compile_cpu` allows you to compile saved models to XLA header+object files and include them in your C++ programs. + * Enable `Igamma`, `Igammac` for XLA. +* Deterministic Op Functionality: + * The XLA reduction emitter is deterministic when the environment variable `TF_DETERMINISTIC_OPS` is set to "true" or "1". This extends deterministic `tf.nn.bias_add` back-prop functionality (and therefore also deterministic back-prop of bias-addition in Keras layers) to include when XLA JIT compilation is enabled. + * Fix a problem, when running on a CUDA GPU and when either the environment variable `TF_DETERMINISTIC_OPS` or the environment variable `TF_CUDNN_DETERMINISTIC` is set to "true" or "1", in which some layer configurations led to an exception with the message "No algorithm worked!" +* Tracing and Debugging: + * Add source and destination names to the `_send` traceme to allow easier debugging. + * Add a traceme event to `fastpathexecute`. +* Other: + * Fix an issue with `AUC.reset_states` for multi-label AUC [#35852](https://github.com/tensorflow/tensorflow/issues/35852). + * Fix the TF upgrade script to not delete files when there is a parsing error and the output mode is `in-place`. + * Move `tensorflow/core:framework/*_pyclif` rules to `tensorflow/core/framework:*_pyclif`. + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +372046933, 8bitmp3, aaronhma, Abin Shahab, Aditya Patwardhan, Agoniii, Ahti Kitsik, Alan Yee, Albin Joy, Alex Hoffman, Alexander Grund, Alexandre E. 
Eichenberger, Amit Kumar Jaiswal, amoitra, Andrew Anderson, Angus-Luo, Anthony Barbier, Anton Kachatkou, Anuj Rawat, archis, Arpan-Dhatt, Arvind Sundararajan, Ashutosh Hathidara, autoih, Bairen Yi, Balint Cristian, Bas Aarts, BashirSbaiti, Basit Ayantunde, Ben Barsdell, Benjamin Gaillard, boron, Brett Koonce, Bryan Cutler, Christian Goll, Christian Sachs, Clayne Robison, comet, Daniel Falbel, Daria Zhuravleva, darsh8200, David Truby, Dayananda-V, deepakm, Denis Khalikov, Devansh Singh, Dheeraj R Reddy, Diederik Van Liere, Diego Caballero, Dominic Jack, dothinking, Douman, Drake Gens, Duncan Riach, Ehsan Toosi, ekuznetsov139, Elena Zhelezina, elzino, Ending2015a, Eric Schweitz, Erik Zettel, Ethan Saadia, Eugene Kuznetsov, Evgeniy Zheltonozhskiy, Ewout Ter Hoeven, exfalso, FAIJUL, Fangjun Kuang, Fei Hu, Frank Laub, Frederic Bastien, Fredrik Knutsson, frreiss, Frédéric Rechtenstein, fsx950223, Gaurav Singh, gbaned, George Grzegorz Pawelczak, George Sterpu, Gian Marco Iodice, Giorgio Arena, Hans Gaiser, Hans Pabst, Haoyu Wu, Harry Slatyer, hsahovic, Hugo, Hugo Sjöberg, IrinaM21, jacco, Jake Tae, Jean-Denis Lesage, Jean-Michel Gorius, Jeff Daily, Jens Elofsson, Jerry Shih, jerryyin, Jin Mingjian, Jinjing Zhou, JKIsaacLee, jojimonv, Jonathan Dekhtiar, Jose Ignacio Gomez, Joseph-Rance, Judd, Julian Gross, Kaixi Hou, Kaustubh Maske Patil, Keunwoo Choi, Kevin Hanselman, Khor Chean Wei, Kilaru Yasaswi Sri Chandra Gandhi, Koan-Sin Tan, Koki Ibukuro, Kristian Holsheimer, kurileo, Lakshay Tokas, Lee Netherton, leike666666, Leslie-Fang-Intel, Li, Guizi, LIUJIAN435, Lukas Geiger, Lyo Nguyen, madisetti, Maher Jendoubi, Mahmoud Abuzaina, Manuel Freiberger, Marcel Koester, Marco Jacopo Ferrarotti, Markus Franke, marload, Mbah-Javis, mbhuiyan, Meng Zhang, Michael Liao, MichaelKonobeev, Michal Tarnowski, Milan Straka, minoring, Mohamed Nour Abouelseoud, MoussaMM, Mrinal Jain, mrTsjolder, Måns Nilsson, Namrata Bhave, Nicholas Gao, Niels Ole Salscheider, nikochiko, Niranjan Hasabnis, Nishidha Panpaliya, nmostafa, Noah Trenaman, nuka137, Officium, Owen L - Sfe, Pallavi G, Paul Andrey, Peng Sun, Peng Wu, Phil Pearl, PhilipMay, pingsutw, Pooya Davoodi, PragmaTwice, pshiko, Qwerty71, R Gomathi, Rahul Huilgol, Richard Xiao, Rick Wierenga, Roberto Rosmaninho, ruchit2801, Rushabh Vasani, Sami, Sana Damani, Sarvesh Dubey, Sasan Jafarnejad, Sergii Khomenko, Shane Smiskol, Shaochen Shi, sharkdtu, Shawn Presser, ShengYang1, Shreyash Patodia, Shyam Sundar Dhanabalan, Siju Samuel, Somyajit Chakraborty Sam, Srihari Humbarwadi, srinivasan.narayanamoorthy, Srishti Yadav, Steph-En-M, Stephan Uphoff, Stephen Mugisha, SumanSudhir, Taehun Kim, Tamas Bela Feher, TengLu, Tetragramm, Thierry Herrmann, Tian Jin, tigertang, Tom Carchrae, Tom Forbes, Trent Lo, Victor Peng, vijayphoenix, Vincent Abriou, Vishal Bhola, Vishnuvardhan Janapati, vladbataev, VoVAllen, Wallyss Lima, Wen-Heng (Jack) Chung, wenxizhu, William D. Irons, William Zhang, Xiaoming (Jason) Cui, Xiaoquan Kong, Xinan Jiang, Yasir Modak, Yasuhiro Matsumoto, Yaxun (Sam) Liu, Yong Tang, Ytyt-Yt, yuan, Yuan Mingshuai, Yuan Tang, Yuki Ueda, Yusup, zhangshijin, zhuwenxi + # Release 2.0.1 ## Bug Fixes and Other Changes diff --git a/SECURITY.md b/SECURITY.md index 6fc2c3aa9cc..f3a6c148b2e 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -64,7 +64,7 @@ your model, and we recommend you run the TensorFlow process in a sandbox. It is possible to write models that are secure in a sense that they can safely process untrusted inputs assuming there are no bugs. 
There are two main reasons -to not rely on this: first, it is easy to write models which must not be exposed +to not rely on this: First, it is easy to write models which must not be exposed to untrusted inputs, and second, there are bugs in any software system of sufficient complexity. Letting users control inputs could allow them to trigger bugs either in TensorFlow or in dependent libraries. @@ -149,7 +149,7 @@ attack (or worse). Because TensorFlow behaves correctly, this is not a vulnerability in TensorFlow (although it would be a vulnerability of this hypothetical system). -As a general rule, it is incorrect behavior for Tensorflow to access memory it +As a general rule, it is incorrect behavior for TensorFlow to access memory it does not own, or to terminate in an unclean way. Bugs in TensorFlow that lead to such behaviors constitute a vulnerability. diff --git a/WORKSPACE b/WORKSPACE index 021ed6d2542..ea741c31c7f 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -114,6 +114,14 @@ http_archive( ], ) +http_archive( + name = "person_detect_data", + sha256 = "170542270da256994ce24d1e357f6e84a54fdaf7d28ff2b74725a40b70b082cf", + urls = [ + "https://storage.googleapis.com/download.tensorflow.org/data/tf_lite_micro_person_data_grayscale_2020_05_24.zip", + ], +) + # Required for dependency @com_github_grpc_grpc load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") diff --git a/configure.py b/configure.py index ac9ed0c4d88..c2850beede6 100644 --- a/configure.py +++ b/configure.py @@ -144,7 +144,7 @@ def write_to_bazelrc(line): def write_action_env_to_bazelrc(var_name, var): - write_to_bazelrc('build --action_env %s="%s"' % (var_name, str(var))) + write_to_bazelrc('build --action_env {}="{}"'.format(var_name, str(var))) def run_shell(cmd, allow_non_zero=False, stderr=None): @@ -205,7 +205,7 @@ def setup_python(environ_cp): # Get PYTHON_BIN_PATH, default is the current running python. default_python_bin_path = sys.executable ask_python_bin_path = ('Please specify the location of python. [Default is ' - '%s]: ') % default_python_bin_path + '{}]: ').format(default_python_bin_path) while True: python_bin_path = get_from_env_or_user_or_default(environ_cp, 'PYTHON_BIN_PATH', @@ -215,9 +215,10 @@ def setup_python(environ_cp): if os.path.isfile(python_bin_path) and os.access(python_bin_path, os.X_OK): break elif not os.path.exists(python_bin_path): - print('Invalid python path: %s cannot be found.' % python_bin_path) + print('Invalid python path: {} cannot be found.'.format(python_bin_path)) else: - print('%s is not executable. Is it the python binary?' % python_bin_path) + print('{} is not executable. Is it the python binary?'.format( + python_bin_path)) environ_cp['PYTHON_BIN_PATH'] = '' # Convert python path to Windows style before checking lib and version @@ -236,7 +237,7 @@ def setup_python(environ_cp): default_python_lib_path = python_lib_paths[0] python_lib_path = get_input( 'Please input the desired Python library path to use. 
' - 'Default is [%s]\n' % python_lib_paths[0]) + 'Default is [{}]\n'.format(python_lib_paths[0])) if not python_lib_path: python_lib_path = default_python_lib_path environ_cp['PYTHON_LIB_PATH'] = python_lib_path @@ -252,7 +253,7 @@ def setup_python(environ_cp): # Set-up env variables used by python_configure.bzl write_action_env_to_bazelrc('PYTHON_BIN_PATH', python_bin_path) write_action_env_to_bazelrc('PYTHON_LIB_PATH', python_lib_path) - write_to_bazelrc('build --python_path=\"%s"' % python_bin_path) + write_to_bazelrc('build --python_path=\"{}"'.format(python_bin_path)) environ_cp['PYTHON_BIN_PATH'] = python_bin_path # If choosen python_lib_path is from a path specified in the PYTHONPATH @@ -266,7 +267,7 @@ def setup_python(environ_cp): with open( os.path.join(_TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'), 'w') as f: - f.write('export PYTHON_BIN_PATH="%s"' % python_bin_path) + f.write('export PYTHON_BIN_PATH="{}"'.format(python_bin_path)) def reset_tf_configure_bazelrc(): @@ -320,11 +321,12 @@ def get_var(environ_cp, Raise the error to avoid infinitely looping. """ if not question: - question = 'Do you wish to build TensorFlow with %s support?' % query_item + question = 'Do you wish to build TensorFlow with {} support?'.format( + query_item) if not yes_reply: - yes_reply = '%s support will be enabled for TensorFlow.' % query_item + yes_reply = '{} support will be enabled for TensorFlow.'.format(query_item) if not no_reply: - no_reply = 'No %s' % yes_reply + no_reply = 'No {}'.format(yes_reply) yes_reply += '\n' no_reply += '\n' @@ -368,7 +370,7 @@ def get_var(environ_cp, print(no_reply) var = False else: - print('Invalid selection: %s' % user_input_origin) + print('Invalid selection: {}'.format(user_input_origin)) return var @@ -479,13 +481,13 @@ def check_bazel_version(min_version, max_version): if which('bazel') is None: print('Cannot find bazel. Please install bazel.') sys.exit(0) - curr_version = run_shell( - ['bazel', '--batch', '--bazelrc=/dev/null', 'version']) - for line in curr_version.split('\n'): - if 'Build label: ' in line: - curr_version = line.split('Build label: ')[1] - break + stderr = open(os.devnull, 'wb') + curr_version = run_shell(['bazel', '--version'], + allow_non_zero=True, + stderr=stderr) + if curr_version.startswith('bazel '): + curr_version = curr_version.split('bazel ')[1] min_version_int = convert_version_to_int(min_version) curr_version_int = convert_version_to_int(curr_version) @@ -1009,17 +1011,15 @@ def set_tf_cuda_compute_capabilities(environ_cp): default_cuda_compute_capabilities = native_cuda_compute_capabilities ask_cuda_compute_capabilities = ( - 'Please specify a list of comma-separated ' - 'CUDA compute capabilities you want to ' - 'build with.\nYou can find the compute ' - 'capability of your device at: ' - 'https://developer.nvidia.com/cuda-gpus.\nPlease' - ' note that each additional compute ' - 'capability significantly increases your ' - 'build time and binary size, and that ' - 'TensorFlow only supports compute ' - 'capabilities >= 3.5 [Default is: %s]: ' % - default_cuda_compute_capabilities) + 'Please specify a list of comma-separated CUDA compute capabilities ' + 'you want to build with.\nYou can find the compute capability of your ' + 'device at: https://developer.nvidia.com/cuda-gpus. 
Each capability ' + 'can be specified as "x.y" or "compute_xy" to include both virtual and' + ' binary GPU code, or as "sm_xy" to only include the binary ' + 'code.\nPlease note that each additional compute capability ' + 'significantly increases your build time and binary size, and that ' + 'TensorFlow only supports compute capabilities >= 3.5 [Default is: ' + '%s]: ' % default_cuda_compute_capabilities) tf_cuda_compute_capabilities = get_from_env_or_user_or_default( environ_cp, 'TF_CUDA_COMPUTE_CAPABILITIES', ask_cuda_compute_capabilities, default_cuda_compute_capabilities) @@ -1031,8 +1031,23 @@ def set_tf_cuda_compute_capabilities(environ_cp): for compute_capability in tf_cuda_compute_capabilities.split(','): m = re.match('[0-9]+.[0-9]+', compute_capability) if not m: - print('Invalid compute capability: %s' % compute_capability) - all_valid = False + # We now support sm_35,sm_50,sm_60,compute_70. + sm_compute_match = re.match('(sm|compute)_?([0-9]+[0-9]+)', + compute_capability) + if not sm_compute_match: + print('Invalid compute capability: %s' % compute_capability) + all_valid = False + else: + ver = int(sm_compute_match.group(2)) + if ver < 30: + print( + 'ERROR: TensorFlow only supports small CUDA compute' + ' capabilities of sm_30 and higher. Please re-specify the list' + ' of compute capabilities excluding version %s.' % ver) + all_valid = False + if ver < 35: + print('WARNING: XLA does not support CUDA compute capabilities ' + 'lower than sm_35. Disable XLA when running on older GPUs.') else: ver = float(m.group(0)) if ver < 3.0: @@ -1223,7 +1238,8 @@ def is_reduced_optimize_huge_functions_available(environ_cp): only, as of 2019-11-19). TensorFlow needs this flag to massively reduce compile times, but until 16.4 is officially released, we can't depend on it. - See also https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion + See also + https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion Because it's very annoying to check this manually (to check the MSVC installed versions, you need to use the registry, and it's not clear if Bazel will be @@ -1366,8 +1382,13 @@ def main(): # environment variables. environ_cp = dict(os.environ) - current_bazel_version = check_bazel_version(_TF_MIN_BAZEL_VERSION, - _TF_MAX_BAZEL_VERSION) + try: + current_bazel_version = check_bazel_version(_TF_MIN_BAZEL_VERSION, + _TF_MAX_BAZEL_VERSION) + except subprocess.CalledProcessError as e: + print('Error checking bazel version: ', e.output.decode('UTF-8').strip()) + raise e + _TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version) reset_tf_configure_bazelrc() @@ -1385,7 +1406,6 @@ def main(): # Windows. environ_cp['TF_DOWNLOAD_CLANG'] = '0' environ_cp['TF_NEED_MPI'] = '0' - environ_cp['TF_SET_ANDROID_WORKSPACE'] = '0' if is_macos(): environ_cp['TF_NEED_TENSORRT'] = '0' diff --git a/tensorflow/BUILD b/tensorflow/BUILD index f2018220a56..efbdf89ecea 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -517,18 +517,29 @@ package_group( "//perftools/accelerators/xprof/api/...", "//third_party/py/autograph/...", "//third_party/swift/tensorflow/x10/...", + "//third_party/swift/tensorflow_apis/...", "//tensorflow/...", "//tensorflow_estimator/python/estimator/...", "//tensorflow_models/official/...", ], ) -package_group(name = "ndarray_tensor_allow_list") +package_group( + name = "ndarray_tensor_allow_list", + packages = ["//learning/pathways/..."], +) # Packages that use composite tensors or dispatch. 
# TODO(b/154762408) Remove this package group once it's no longer needed. package_group(name = "composite_tensor_whitelist") +# Packages that use private types symbols, until they are exported. +# TODO(b/154650521) Remove. +package_group( + name = "types_whitelist", + packages = ["//learning/deepmind/tensorflow/replicator/..."], +) + filegroup( name = "intel_binary_blob", data = if_mkl_ml( diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 66ade5c7bd4..12021a294e8 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -85,7 +85,7 @@ tf_cuda_library( ], deps = select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", ], "//tensorflow:chromiumos": [ ":tf_attrtype", @@ -182,7 +182,7 @@ tf_cuda_library( ":tf_status_internal", ] + select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", ], "//conditions:default": [ ":tf_status", @@ -216,10 +216,11 @@ tf_cuda_library( ], visibility = [ "//tensorflow/c:__subpackages__", + "//tensorflow/compiler/mlir/tensorflow/c:__subpackages__", ], deps = select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs ], "//conditions:default": [ "//tensorflow/core:lib", @@ -232,12 +233,13 @@ cc_library( srcs = ["tf_status.cc"], hdrs = ["tf_status.h"], visibility = ["//visibility:public"], - deps = select({ + deps = [ + ":tf_status_internal", + ] + select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs ], "//conditions:default": [ - ":tf_status_internal", "//tensorflow/core:lib", ], }), @@ -259,10 +261,15 @@ cc_library( name = "tensor_interface", hdrs = ["tensor_interface.h"], visibility = ["//tensorflow:internal"], - deps = [ - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - ], + deps = select({ + "//tensorflow:android": [ + "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs + ], + "//conditions:default": [ + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], + }), ) cc_library( @@ -272,7 +279,7 @@ cc_library( visibility = ["//visibility:public"], deps = select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs + "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs ], "//conditions:default": [ "//tensorflow/core:framework", @@ -286,16 +293,17 @@ cc_library( srcs = ["tf_tensor.cc"], hdrs = ["tf_tensor.h"], visibility = ["//visibility:public"], - deps = select({ + deps = [ + ":tensor_interface", + ":tf_datatype", + ":tf_status", + ":tf_status_helper", + ":tf_tensor_internal", + ] + select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs ], "//conditions:default": [ - ":tensor_interface", - ":tf_datatype", - ":tf_status", - ":tf_status_helper", - ":tf_tensor_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -311,14 +319,15 @@ tf_cuda_library( "tf_tensor_internal.h", ], visibility = ["//tensorflow:internal"], - deps = select({ + deps = [ + ":tensor_interface", + ":tf_datatype", + ":tf_status", + ] + select({ 
"//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs ], "//conditions:default": [ - ":tensor_interface", - ":tf_datatype", - ":tf_status", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:casts", @@ -386,8 +395,14 @@ tf_cuda_library( deps = [ ":tf_status", ":tf_status_internal", - "//tensorflow/core:lib", - ], + ] + select({ + "//tensorflow:android": [ + "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs + ], + "//conditions:default": [ + "//tensorflow/core:lib", + ], + }), ) tf_cc_test( @@ -426,7 +441,7 @@ tf_cuda_library( visibility = ["//visibility:public"], deps = select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", ], "//conditions:default": [ "//tensorflow/core:framework", @@ -457,7 +472,7 @@ tf_cuda_library( ] + select({ "//tensorflow:android": [ ":c_api_internal", - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", ], "//conditions:default": [ ":c_api_internal", @@ -484,7 +499,7 @@ tf_cuda_library( ":tf_status_helper", ] + select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", ], "//conditions:default": [ "//tensorflow/core:framework", diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index e623f30b98c..e9e6d470c68 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -325,205 +325,6 @@ TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status) { return ret; } -TFE_Context* TFE_CreateContextFromSession(TF_Session* session, - TF_Status* status) { - auto* opts = TFE_NewContextOptions(); - - // Reduce GPU memory allocation, and set appropriate config options for TFE - // context. - auto* config = TF_CreateConfig( - /*xla*/ false, /* gpu_memory_allow_growth */ true, /* num_cpu_devices */ - 10); - TFE_ContextOptionsSetConfig(opts, config->data, config->length, status); - if (!status->status.ok()) { - CHECK(!config); - TFE_DeleteContextOptions(opts); - return nullptr; - } - - auto* ctx = TFE_NewContextFromSession(opts, session, status); - TF_DeleteBuffer(config); - TFE_DeleteContextOptions(opts); - return ctx; -} - -// TODO: retrieve the device string via TFE_ContextListDevices() -static const char DEFAULT_CPU_DEVICE[] = - "/job:localhost/replica:0/task:0/device:CPU:0"; - -static TFE_TensorHandle* createTFEQueue(TFE_Context* ctx, TF_DataType inputType, - int tensor_id, TF_Status* status) { - std::unique_ptr queueOp( - TFE_NewOp(ctx, "FIFOQueueV2", status), TFE_DeleteOp); - TFE_OpSetDevice(queueOp.get(), DEFAULT_CPU_DEVICE, status); - if (!status->status.ok()) return nullptr; - // TODO: use NAMED_TENSOR_QUEUE_CAPACITY in S4TF compiler. - TFE_OpSetAttrInt(queueOp.get(), "capacity", 1); - TFE_OpSetAttrTypeList(queueOp.get(), "component_types", &inputType, 1); - auto shared_name = tensorflow::strings::StrCat("fifo_queue_", tensor_id); - TFE_OpSetAttrString(queueOp.get(), "shared_name", shared_name.data(), - shared_name.size()); - TFE_OpSetAttrString(queueOp.get(), "container", "", 0); - - // TODO: consider making this an unknown shape. 
- const int64_t* dims_ptr = nullptr; - int num_dims = 0; - TFE_OpSetAttrShapeList(queueOp.get(), "shapes", &dims_ptr, &num_dims, - /*num_values*/ 0, status); - if (!status->status.ok()) return nullptr; - - int num_retvals = 1; - TFE_TensorHandle* queue = nullptr; - TFE_Execute(queueOp.get(), &queue, &num_retvals, status); - if (!status->status.ok()) return nullptr; - CHECK_EQ(num_retvals, 1); - - return queue; -} - -static void createTFEEnqueue(TFE_Context* ctx, TF_DataType inputType, - TFE_TensorHandle* queue, TFE_TensorHandle* tensor, - TF_Status* status) { - TFE_Op* op = TFE_NewOp(ctx, "QueueEnqueueV2", status); - if (!status->status.ok()) return; - std::unique_ptr op_deleter(op, TFE_DeleteOp); - TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status); - if (!status->status.ok()) return; - TFE_OpAddInput(op, queue, status); - if (!status->status.ok()) return; - TFE_OpAddInput(op, tensor, status); - if (!status->status.ok()) return; - TFE_OpSetAttrTypeList(op, "Tcomponents", &inputType, 1); - TFE_OpSetAttrInt(op, "timeout_ms", -1); - - int num_retvals = 0; - TFE_Execute(op, nullptr /*retvals*/, &num_retvals, status); - if (!status->status.ok()) return; - CHECK_EQ(num_retvals, 0); -} - -static TFE_TensorHandle* createTFEDequeue(TFE_Context* ctx, - TF_DataType inputType, - TFE_TensorHandle* queue, - TF_Status* status) { - TFE_Op* op = TFE_NewOp(ctx, "QueueDequeueV2", status); - if (!status->status.ok()) return nullptr; - std::unique_ptr op_deleter(op, TFE_DeleteOp); - TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status); - if (!status->status.ok()) return nullptr; - - TFE_OpAddInput(op, queue, status); - if (!status->status.ok()) return nullptr; - TFE_OpSetAttrTypeList(op, "component_types", &inputType, 1); - TFE_OpSetAttrInt(op, "timeout_ms", -1); - TFE_TensorHandle* ret; - int num_retvals = 1; - TFE_Execute(op, &ret, &num_retvals, status); - if (!status->status.ok()) return nullptr; - CHECK_EQ(num_retvals, 1); - return ret; -} - -TFE_TensorHandle* TFE_DequeueNamedTensor(TF_Session* session, int tensor_id, - TF_DataType inputType, - TF_Status* status) { - assert(session); - VLOG(1) << "Dequeuing data tensor with id " << tensor_id; - - auto ctx = TFE_CreateContextFromSession(session, status); - if (!status->status.ok()) return nullptr; - std::unique_ptr ctx_deleter( - ctx, TFE_DeleteContext); - - TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status); - if (!status->status.ok()) return nullptr; - std::unique_ptr - queue_deleter(queue, TFE_DeleteTensorHandle); - - auto* ret = createTFEDequeue(ctx, inputType, queue, status); - return ret; -} - -TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id, - TF_DataType inputType, - TF_Status* status) { - TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status); - if (!status->status.ok()) return nullptr; - std::unique_ptr - queue_deleter(queue, TFE_DeleteTensorHandle); - - auto* ret = createTFEDequeue(ctx, inputType, queue, status); - - return ret; -} - -void TFE_EnqueueNamedTensor(TF_Session* session, int tensor_id, - TFE_TensorHandle* tensor, TF_Status* status) { - assert(session); - VLOG(1) << "Enqueuing data tensor with id " << tensor_id; - - auto ctx = TFE_CreateContextFromSession(session, status); - if (!status->status.ok()) return; - std::unique_ptr ctx_deleter( - ctx, TFE_DeleteContext); - - TF_DataType inputType = TFE_TensorHandleDataType(tensor); - TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status); - if (!status->status.ok()) return; - std::unique_ptr - 
queue_deleter(queue, TFE_DeleteTensorHandle); - - createTFEEnqueue(ctx, inputType, queue, tensor, status); -} - -void TFE_EnqueueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id, - TFE_TensorHandle* tensor, - TF_Status* status) { - VLOG(1) << "Enqueuing data tensor with id " << tensor_id; - - TF_DataType inputType = TFE_TensorHandleDataType(tensor); - TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status); - if (!status->status.ok()) return; - std::unique_ptr - queue_deleter(queue, TFE_DeleteTensorHandle); - - createTFEEnqueue(ctx, inputType, queue, tensor, status); -} - -void TFE_EnqueueVariantTensor(TF_Session* session, int tensor_id, - TFE_TensorHandle* tensor, TF_Status* status) { - VLOG(1) << "Enqueuing variant tensor with id " << tensor_id; - - auto ctx = TFE_CreateContextFromSession(session, status); - if (!status->status.ok()) return; - std::unique_ptr ctx_deleter( - ctx, TFE_DeleteContext); - - TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status); - if (!status->status.ok()) return; - std::unique_ptr - queue_deleter(queue, TFE_DeleteTensorHandle); - - createTFEEnqueue(ctx, TF_VARIANT, queue, tensor, status); -} - -TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id, - TF_Status* status) { - VLOG(1) << "Dequeuing variant tensor with id " << tensor_id; - - auto ctx = TFE_CreateContextFromSession(session, status); - if (!status->status.ok()) return nullptr; - std::unique_ptr ctx_deleter( - ctx, TFE_DeleteContext); - - TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status); - if (!status->status.ok()) return nullptr; - std::unique_ptr - queue_deleter(queue, TFE_DeleteTensorHandle); - - return createTFEDequeue(ctx, TF_VARIANT, queue, status); -} - void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) { status->status = tensorflow::errors::Internal(errMsg); } @@ -622,10 +423,9 @@ void TF_AttrBuilderSetType(TF_AttrBuilder* builder, const char* attr_name, void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder, const char* attr_name, const TF_DataType* values, int num_values) { auto iter = builder->attr_names.insert(attr_name).first; - builder->Set( - (*iter).c_str(), - tensorflow::gtl::ArraySlice( - reinterpret_cast(values), num_values)); + builder->Set(*iter, tensorflow::gtl::ArraySlice( + reinterpret_cast(values), + num_values)); } void TF_AttrBuilderCheckCanRunOnDevice(TF_AttrBuilder* builder, diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 551a45d92c4..d0ffbf125fb 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -146,48 +146,6 @@ TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session, // Create a serialized tensorflow.ServerDef proto. TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status); -// TODO: remove this API in favor of the next one. -TF_CAPI_EXPORT extern TFE_Context* TFE_NewContextFromSession( - const TFE_ContextOptions* opts, TF_Session* sess, TF_Status* status); - -// Creates from `session` a new eager context to run a graph function or -// sends/recvs, so that these concurrent TFE executions can share (via -// `session` and its associated device mgr) the same set of fifo queue resource -// ops, used for host<->TF tensor transfers. This way the sends/recvs calls and -// graph function execution can access the same fifo queue resource handles -// (associated with devices managed by the device manager, which can be obtained -// from `session`). 
-// -// TODO: Remove this function once we migrate away from using session. -TF_CAPI_EXPORT extern TFE_Context* TFE_CreateContextFromSession( - TF_Session* session, TF_Status* status); - -// TODO: Retire this API in favor of the next one. -TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensor( - TF_Session* session, int tensor_id, TF_DataType inputType, - TF_Status* status); - -TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx( - TFE_Context* ctx, int tensor_id, TF_DataType inputType, TF_Status* status); - -TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensor(TF_Session* session, - int tensor_id, - TFE_TensorHandle* tensor, - TF_Status* status); - -TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensorFromCtx( - TFE_Context* ctx, int tensor_id, TFE_TensorHandle* tensor, - TF_Status* status); - -// TODO: consider folding the 2 APIs below into the ones above. -TF_CAPI_EXPORT extern void TFE_EnqueueVariantTensor(TF_Session* session, - int tensor_id, - TFE_TensorHandle* tensor, - TF_Status* status); - -TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor( - TF_Session* session, int tensor_id, TF_Status* status); - TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg); diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 3d3fc7065a4..407acfe1ca9 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -16,7 +16,6 @@ load( "//tensorflow/core/platform:build_config_root.bzl", "tf_cuda_tests_tags", ) -load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") package( licenses = ["notice"], # Apache 2.0 @@ -36,7 +35,7 @@ tf_cuda_library( visibility = ["//visibility:public"], deps = select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", ], "//conditions:default": [ ":context_interface", @@ -145,6 +144,24 @@ cc_library( ], ) +cc_library( + name = "c_api_unified_internal", + hdrs = [ + "c_api_unified_experimental_internal.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":c_api", + ":c_api_experimental", + "//tensorflow/c:c_api_internal", + "//tensorflow/c:tf_status", + "//tensorflow/core/platform:casts", + "//tensorflow/core/platform:types", + ], +) + cc_library( name = "tensor_handle_interface", hdrs = ["tensor_handle_interface.h"], @@ -320,6 +337,7 @@ tf_cuda_cc_test( tags = [ "noguitar", # TODO(b/155445984): flaky #"guitar", + "notap", # TODO(b/156981931): flaky "multi_gpu", ], deps = [ @@ -350,6 +368,38 @@ tf_cuda_cc_test( # TODO(b/136478427): Figure out how to correctly shut the server down args = ["--heap_check=local"], extra_copts = tfe_xla_copts(), + tags = [ + "noasan", # leaks gRPC server instances + ], + deps = [ + ":c_api", + ":c_api_experimental", + ":c_api_internal", + ":c_api_test_util", + ":tfe_tensorhandle_internal", + "//tensorflow/c:c_test_util", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/common_runtime:function_optimization_registry", + "//tensorflow/core/common_runtime/eager:eager_operation", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", + "@com_google_absl//absl/strings", + ], +) + +tf_cuda_cc_test( + name = "c_api_distributed_test", + size = "small", + srcs = [ + "c_api_distributed_test.cc", + ], + # TODO(b/136478427): Figure out how to correctly shut the server down + args = 
["--heap_check=local"], + extra_copts = tfe_xla_copts(), tags = ["noasan"], # leaks gRPC server instances deps = [ ":c_api", @@ -358,10 +408,13 @@ tf_cuda_cc_test( ":c_api_test_util", ":tfe_tensorhandle_internal", "//tensorflow/c:c_test_util", + "//tensorflow/core:framework", + "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/common_runtime:function_optimization_registry", "//tensorflow/core/common_runtime/eager:eager_operation", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "@com_google_absl//absl/strings", @@ -413,7 +466,7 @@ tf_cuda_library( visibility = ["//visibility:public"], deps = select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", ], "//conditions:default": [ ":c_api", @@ -449,6 +502,8 @@ tf_cuda_library( "//conditions:default": [], }) + [ "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/container:flat_hash_map", "//tensorflow/c:tf_status_helper", "//tensorflow/core/distributed_runtime/eager:eager_client", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", @@ -506,6 +561,7 @@ tf_cuda_cc_test( "//tensorflow/c:c_api", "//tensorflow/c:c_test_util", "//tensorflow/cc/profiler", + "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", @@ -609,7 +665,6 @@ filegroup( ], exclude = [ "c_api_experimental.cc", - "*c_api_tfrt*", "*test*", "*dlpack*", ], diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 9651a47d6ac..5a39c17e1d9 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -38,7 +38,7 @@ limitations under the License. #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/c/tf_tensor_internal.h" #ifdef PLATFORM_GOOGLE -#include "tensorflow/c/eager/c_api_tfrt.h" +#include "tensorflow/core/tfrt/eager/c_api_tfrt.h" #endif #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/eager/context.h" @@ -102,6 +102,15 @@ string DeviceName(const tensorflow::Device* d) { } #if !defined(IS_MOBILE_PLATFORM) +bool AreLocalDevicesCompatible(const tensorflow::EagerContext* context, + const tensorflow::ServerDef& server_def) { + if (server_def.job_name() != context->HostCPU()->parsed_name().job) { + return false; + } + return server_def.default_session_config().SerializeAsString() == + context->session_options().config.SerializeAsString(); +} + tensorflow::Status AddRemoteDevicesToMgr( const std::vector& added_remote_workers, tensorflow::WorkerCacheInterface* worker_cache, @@ -469,10 +478,15 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); tensorflow::GrpcServer* grpc_server; if (reset_context) { - LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server)); + const tensorflow::DeviceMgr* device_mgr = + AreLocalDevicesCompatible(context, server_def) + ? 
context->local_device_mgr() + : nullptr; + LOG_AND_RETURN_IF_ERROR(tensorflow::NewServerWithOptions( + server_def, {device_mgr}, &new_server)); grpc_server = dynamic_cast(new_server.get()); LOG_AND_RETURN_IF_ERROR( - ListRemoteWorkers(grpc_server, worker_name, &remote_workers)); + ListRemoteWorkers(new_server.get(), worker_name, &remote_workers)); } else { LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name, &curr_remote_workers)); @@ -727,24 +741,6 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { tensorflow::GetDefaultCustomKernelCreator())); } -TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts, - TF_Session* sess, TF_Status* status) { - const tensorflow::DeviceMgr* device_mgr = nullptr; - status->status = sess->session->LocalDeviceManager(&device_mgr); - if (!status->status.ok()) return nullptr; - tensorflow::Rendezvous* r = - new tensorflow::IntraProcessRendezvous(device_mgr); - - return tensorflow::wrap(new tensorflow::EagerContext( - opts->session_options.options, - static_cast( - opts->device_placement_policy), - static_cast(opts->mirroring_policy), - opts->async, opts->lazy_remote_inputs_copy, device_mgr, - /*device_mgr_owned*/ false, r, - tensorflow::GetDefaultCustomKernelCreator())); -} - void TFE_DeleteContext(TFE_Context* ctx) { if (ctx == nullptr) { return; @@ -899,9 +895,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx, #if defined(IS_MOBILE_PLATFORM) status->status = tensorflow::Status::OK(); #else // !defined(IS_MOBILE_PLATFORM) - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - status->status = context->SyncExecutors(); + status->status = tensorflow::unwrap(ctx)->AsyncWait(); #endif // !IS_MOBILE_PLATFORM } @@ -924,7 +918,7 @@ extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( context->GetDevicePlacementPolicy()); } -TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { +TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) { tensorflow::Tensor tensor; status->status = tensorflow::TF_TensorToTensor(t, &tensor); if (!status->status.ok()) return nullptr; diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 070b3a9bb60..5afe3047dd7 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -137,7 +137,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, // placed in memory of different devices or remote address spaces. typedef struct TFE_TensorHandle TFE_TensorHandle; -TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status); // Indicates that the caller will not be using `h` any more. 
TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h); diff --git a/tensorflow/c/eager/c_api_cluster_test.cc b/tensorflow/c/eager/c_api_cluster_test.cc index 8f585d6f02c..7a604950a63 100644 --- a/tensorflow/c/eager/c_api_cluster_test.cc +++ b/tensorflow/c/eager/c_api_cluster_test.cc @@ -30,24 +30,11 @@ namespace { using ::tensorflow::string; -tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) { - tensorflow::ServerDef server_def; - server_def.set_protocol("grpc"); - server_def.set_job_name(job_name); - server_def.set_task_index(0); - tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster(); - tensorflow::JobDef* job_def = cluster_def->add_job(); - job_def->set_name(job_name); - for (int i = 0; i < num_tasks; i++) { - int port = tensorflow::testing::PickUnusedPortOrDie(); - job_def->mutable_tasks()->insert( - {i, tensorflow::strings::StrCat("localhost:", port)}); - } - return server_def; -} - -tensorflow::ServerDef GetServerDef(int num_tasks) { - return GetServerDef("localhost", num_tasks); +void ReplaceTaskInServerDef(tensorflow::ServerDef* server_def, int task_index) { + tensorflow::JobDef* job_def = server_def->mutable_cluster()->mutable_job(0); + int port = tensorflow::testing::PickUnusedPortOrDie(); + job_def->mutable_tasks()->at(task_index) = + tensorflow::strings::StrCat("localhost:", port); } void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle, @@ -101,6 +88,22 @@ void CheckRemoteMatMulExecutesOK(TFE_Context* ctx, TF_DeleteStatus(status); } +// Read the value of variable `var` and save it into `out_value`. +void ReadVariable(TFE_Context* ctx, TFE_TensorHandle* var, + TFE_TensorHandle** out_value) { + TF_Status* status = TF_NewStatus(); + TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpSetAttrType(op, "dtype", TF_FLOAT); + TFE_OpAddInput(op, var, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + int num_retvals = 1; + TFE_Execute(op, out_value, &num_retvals, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteOp(op); + TF_DeleteStatus(status); +} + void TestRemoteExecuteChangeServerDef(bool async) { tensorflow::ServerDef server_def = GetServerDef(2); @@ -243,6 +246,102 @@ TEST(CAPI, RemoteExecuteUpdateServerDefAsync) { TestRemoteExecuteUpdateServerDef(true); } +void TestRemoteExecuteUpdateServerDefResourceAccess(bool async) { + tensorflow::ServerDef server_def = GetServerDef(2); + // This server def has the task index set to 0. 
+ string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + std::unique_ptr worker_server; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server) + .ok()); + ASSERT_TRUE(worker_server->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + const char dev0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0"; + const char dev1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0"; + + TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name); + EXPECT_NE(var_handle0, nullptr); + TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name); + EXPECT_NE(var_handle1, nullptr); + + TFE_TensorHandle* value_handle = nullptr; + ReadVariable(ctx, var_handle1, &value_handle); + CheckTFE_TensorHandleHasFloats(value_handle, {2}); + TFE_DeleteTensorHandle(value_handle); + + // Start a new worker to replace task:1 + ReplaceTaskInServerDef(&server_def, 1); + server_def.set_task_index(1); + // TODO(b/136478427): Figure out how to correctly shut the server down. + worker_server.release(); + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server) + .ok()); + ASSERT_TRUE(worker_server->Start().ok()); + + // Update server def to replace the remote device with the device info on the + // new worker (different incarnation ID). + server_def.set_task_index(0); + string serialized_update = server_def.SerializeAsString(); + TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(), + serialized_update.size(), status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // The device of var_handle0 is local device which is the same before and + // after cluster update. Remove resource with valid device should succeed. + TFE_Op* op = TFE_NewOp(ctx, "DestroyResourceOp", status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, var_handle0, status); + TFE_OpSetDevice(op, dev0_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + int num_retvals = 0; + TFE_Execute(op, nullptr, &num_retvals, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteOp(op); + + // The device of var_handle1 is remote device, which was replaced during + // cluster update. Removing resource with invalid device should fail + // gracefully (i.e., with error status) instead of crashing with segfaults. + op = TFE_NewOp(ctx, "DestroyResourceOp", status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, var_handle1, status); + TFE_OpSetDevice(op, dev1_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + num_retvals = 0; + TFE_Execute(op, nullptr, &num_retvals, status); + EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteOp(op); + + TFE_DeleteTensorHandle(var_handle0); + TFE_DeleteTensorHandle(var_handle1); + + TFE_DeleteContext(ctx); + TF_DeleteStatus(status); + + // TODO(b/136478427): Figure out how to correctly shut the server down. 
+ worker_server.release(); +} + +TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccess) { + TestRemoteExecuteUpdateServerDefResourceAccess(false); +} + +TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccessAsync) { + TestRemoteExecuteUpdateServerDefResourceAccess(true); +} + void TestRemoteExecuteUpdateServerDefWithFailures(bool async) { // Fail fast on GetStatus requests so we can get errors instead of timeout // when updating cluster with non-exsitent worker @@ -282,6 +381,7 @@ void TestRemoteExecuteUpdateServerDefWithFailures(bool async) { int port = tensorflow::testing::PickUnusedPortOrDie(); job_def->mutable_tasks()->insert( {2, tensorflow::strings::StrCat("localhost:", port)}); + server_def.set_task_index(0); string serialized_update = server_def.SerializeAsString(); TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(), serialized_update.size(), status); @@ -310,4 +410,70 @@ TEST(CAPI, RemoteExecuteUpdateServerDefWithFailuresAsync) { TestRemoteExecuteUpdateServerDefWithFailures(true); } +void TestConnectToCluster(bool keep_localhost_for_first_connect) { + // Fail fast on GetStatus requests so we can get errors instead of timeout + // when updating cluster with non-exsitent worker + tensorflow::setenv("GRPC_FAIL_FAST", "TRUE", /*overwrite=*/1); + + const string first_name = + keep_localhost_for_first_connect ? "localhost" : "abc"; + tensorflow::ServerDef server_def = GetServerDef(first_name, 1); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + const string dev0_name = "/job:localhost/replica:0/task:0/device:CPU:0"; + TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name); + EXPECT_NE(var_handle0, nullptr); + + tensorflow::Status status2; + EXPECT_EQ(tensorflow::unwrap(var_handle0)->DeviceName(&status2), dev0_name); + + // Rename local device + // This server def has the task index set to 0. 
+ string serialized = server_def.SerializeAsString(); + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + const string dev1_name = + absl::StrCat("/job:", first_name, "/replica:0/task:0/device:CPU:0"); + TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name); + EXPECT_NE(var_handle1, nullptr); + EXPECT_EQ(tensorflow::unwrap(var_handle1)->DeviceName(&status2), dev1_name); + + // Another renaming of local device + const string second_name = "def"; + server_def.set_job_name(second_name); + server_def.mutable_cluster()->mutable_job(0)->set_name(second_name); + (*server_def.mutable_cluster()->mutable_job(0)->mutable_tasks())[0] = + absl::StrCat(second_name, ":", + tensorflow::testing::PickUnusedPortOrDie()); + + serialized = server_def.SerializeAsString(); + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + const string dev2_name = "/job:def/replica:0/task:0/device:CPU:0"; + TFE_TensorHandle* var_handle2 = TestVariable(ctx, 2.0, dev2_name); + EXPECT_NE(var_handle2, nullptr); + EXPECT_EQ(tensorflow::unwrap(var_handle2)->DeviceName(&status2), dev2_name); + + TFE_DeleteTensorHandle(var_handle0); + TFE_DeleteTensorHandle(var_handle1); + TFE_DeleteTensorHandle(var_handle2); + + TFE_DeleteContext(ctx); + TF_DeleteStatus(status); + + tensorflow::unsetenv("GRPC_FAIL_FAST"); +} + +TEST(CAPI, ConnectToClusterLocalhostFirst) { TestConnectToCluster(false); } + +TEST(CAPI, ConnectToClusterRenameFirst) { TestConnectToCluster(true); } + } // namespace diff --git a/tensorflow/c/eager/c_api_distributed_test.cc b/tensorflow/c/eager/c_api_distributed_test.cc new file mode 100644 index 00000000000..65f8d3cc646 --- /dev/null +++ b/tensorflow/c/eager/c_api_distributed_test.cc @@ -0,0 +1,506 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" +#include "tensorflow/core/common_runtime/eager/eager_operation.h" +#include "tensorflow/core/common_runtime/function_optimization_registry.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/casts.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/cluster.pb.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/tensorflow_server.pb.h" + +namespace { + +using ::tensorflow::string; + +// Add the values of three variables on three different tasks. +string AddVariablesFunction() { + tensorflow::FunctionDef def; + CHECK(tensorflow::protobuf::TextFormat::ParseFromString( + " signature {" + " name: 'AddVariablesFunction'" + " input_arg {" + " name: 'var'" + " type: DT_RESOURCE" + " }" + " output_arg {" + " name: 'sum'" + " type: DT_FLOAT" + " }" + " }" + " node_def {" + " name: 'read0'" + " op: 'ReadVariableOp'" + " input: 'var'" + " device: '/job:localhost/replica:0/task:0/device:CPU:0'" + " attr {" + " key: 'dtype'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " node_def {" + " name: 'read1'" + " op: 'ReadVariableOp'" + " input: 'var'" + " device: '/job:localhost/replica:0/task:1/device:CPU:0'" + " attr {" + " key: 'dtype'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " node_def {" + " name: 'read2'" + " op: 'ReadVariableOp'" + " input: 'var'" + " device: '/job:localhost/replica:0/task:2/device:CPU:0'" + " attr {" + " key: 'dtype'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " node_def {" + " name: 'add1'" + " op: 'Add'" + " input: 'read0:value:0'" + " input: 'read1:value:0'" + " attr {" + " key: 'T'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " node_def {" + " name: 'add2'" + " op: 'Add'" + " input: 'add1:z:0'" + " input: 'read2:value:0'" + " attr {" + " key: 'T'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " ret {" + " key: 'sum'" + " value: 'add2:z:0'" + " }", + &def)); + return def.SerializeAsString(); +} + +void VarIsInitialized(TFE_Context* ctx, TFE_TensorHandle* var_handle) { + TF_Status* status = TF_NewStatus(); + TFE_Op* op = TFE_NewOp(ctx, "VarIsInitializedOp", status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TFE_OpAddInput(op, var_handle, status); + TFE_TensorHandle* is_initialized[1] = {nullptr}; + int num_retvals = 1; + TFE_Execute(op, &is_initialized[0], &num_retvals, status); + CHECK_EQ(1, num_retvals); + TF_Tensor* t = TFE_TensorHandleResolve(is_initialized[0], status); + bool initialized = false; + memcpy(&initialized, TF_TensorData(t), TF_TensorByteSize(t)); + EXPECT_EQ(initialized, true); + TF_DeleteTensor(t); + TFE_DeleteTensorHandle(is_initialized[0]); + TFE_DeleteOp(op); + delete status; +} + +void TestFunctionWithPackedInput(const bool remote) { + tensorflow::ServerDef server_def = GetServerDef(3); + + // This server def has the task index set to 0. 
+ string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + std::unique_ptr worker_server1; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server1) + .ok()); + ASSERT_TRUE(worker_server1->Start().ok()); + + server_def.set_task_index(2); + std::unique_ptr worker_server2; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server2) + .ok()); + ASSERT_TRUE(worker_server2->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(/*enable=*/true)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + const char task0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0"; + const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0"; + const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0"; + + // Create one variable per task. + TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task0_name); + TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task1_name); + TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task2_name); + + // Add a sync point in order to make sure that variables have been initialized + // before the function execution starts. + // TODO(b/155789951): Remove once b/155789951 is fixed. + VarIsInitialized(ctx, h1); + VarIsInitialized(ctx, h2); + + // Pack 3 variable handles into one TFE_TensorHandle. + int num_replicas = 3; + std::vector handles = {h0, h1, h2}; + TFE_TensorHandle* packed_handle = + TFE_CreatePackedTensorHandle(ctx, handles.data(), &num_replicas, status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + EXPECT_EQ(TFE_TensorHandleDataType(packed_handle), TF_RESOURCE); + EXPECT_EQ(TFE_TensorHandleNumDims(packed_handle, status), 0); + EXPECT_EQ(TFE_TensorHandleNumElements(packed_handle, status), 1); + + const string composite_device_name = + "/job:localhost/replica:0/task:0/device:COMPOSITE:0"; + EXPECT_EQ(TFE_TensorHandleDeviceName(packed_handle, status), + composite_device_name); + EXPECT_EQ(TFE_TensorHandleBackingDeviceName(packed_handle, status), + composite_device_name); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + // Register and run a function which returns the sum of 3 variables. 
+ const string function_def = AddVariablesFunction(); + TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(), + status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TFE_Op* func = TFE_NewOp(ctx, "AddVariablesFunction", status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TFE_OpAddInput(func, packed_handle, status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + if (remote) { + TFE_OpSetDevice(func, task1_name, status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + } + + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; + TFE_Execute(func, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + ASSERT_EQ(1, num_retvals); + TFE_DeleteOp(func); + TFE_DeleteTensorHandle(packed_handle); + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TFE_DeleteTensorHandle(retvals[0]); + float sum = 0; + EXPECT_EQ(sizeof(sum), TF_TensorByteSize(t)); + memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(sum, 6.0); + + TFE_DeleteTensorHandle(h0); + TFE_DeleteTensorHandle(h1); + TFE_DeleteTensorHandle(h2); + + TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorWaitForAllPendingNodes(executor, status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TFE_DeleteExecutor(executor); + TFE_ContextRemoveFunction(ctx, "AddVariablesFunction", status); + TFE_DeleteContext(ctx); + + TF_DeleteStatus(status); + + // TODO(b/136478427): Figure out how to correctly shut the server down. + worker_server1.release(); + worker_server2.release(); +} + +TEST(CAPI, TestLocalFunctionWithPackedInput) { + TestFunctionWithPackedInput(/*remote=*/false); +} + +TEST(CAPI, TestRemoteFunctionWithPackedInput) { + TestFunctionWithPackedInput(/*remote=*/true); +} + +string VariableAddFunction() { + tensorflow::FunctionDef def; + CHECK(tensorflow::protobuf::TextFormat::ParseFromString( + " signature {" + " name: 'VariableAddFunction'" + " input_arg {" + " name: 'var0'" + " type: DT_RESOURCE" + " }" + " output_arg {" + " name: 'var0_value'" + " type: DT_FLOAT" + " }" + " }" + " node_def {" + " name: 'read0'" + " op: 'ReadVariableOp'" + " input: 'var0'" + " attr {" + " key: 'dtype'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " node_def {" + " name: 'add'" + " op: 'Add'" + " input: 'read0:value:0'" + " input: 'read0:value:0'" + " device: '/job:localhost/task:1/device:CPU:0'" + " attr {" + " key: 'T'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " node_def {" + " name: 'identity'" + " op: 'Identity'" + " input: 'add:z:0'" + " device: '/job:localhost/task:0/device:CPU:0'" + " attr {" + " key: 'T'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " ret {" + " key: 'var0_value'" + " value: 'identity:output:0'" + " }", + &def)); + return def.SerializeAsString(); +} + +class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass { + public: + FunctionErrorInjectionPass(string error_node, string error_device) + : error_node_(error_node), error_device_(error_device) {} + tensorflow::Status Run(const tensorflow::DeviceSet& device_set, + const tensorflow::ConfigProto& config_proto, + std::unique_ptr* graph, + tensorflow::FunctionLibraryDefinition* flib_def, + std::vector* control_ret_node_names, + bool* control_rets_updated) override { + // Inject failure to function instantiation if finding a node that 
contains + // the given node name (error_node_) and requested device (error_device_). + for (const auto node : graph->get()->nodes()) { + if (node->name().find(error_node_) != string::npos && + node->requested_device() == error_device_) { + return tensorflow::errors::Internal("Injected graph pass error."); + } + } + return tensorflow::Status::OK(); + } + + private: + const string error_node_; + const string error_device_; +}; + +void TestDistributedFunctionCancellation(bool inject_error) { + tensorflow::ServerDef server_def = GetServerDef(3); + // This server def has the task index set to 0. + string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + std::unique_ptr worker_server1; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server1) + .ok()); + ASSERT_TRUE(worker_server1->Start().ok()); + server_def.set_task_index(2); + std::unique_ptr worker_server2; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server2) + .ok()); + ASSERT_TRUE(worker_server2->Start().ok()); + const char dev2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0"; + + if (inject_error) { + // Inject a function optimization pass failure when it sees the 'read0' op + // having a requested device `dev2_name`. During execution: + // * task:0 processes the main function `VariableAddFunction` and places + // the read0 op on task:2 + // * task:0 partitions the main function with a subgraph containing read0 + // sent to task:2 + // * task:2 graph pass reports an error when it sees read0 with dev2_name + tensorflow::function_optimization_registration:: + FunctionOptimizationPassRegistration register_test_pass( + std::make_unique("read0", dev2_name)); + } + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name); + EXPECT_NE(var_handle, nullptr); + + const string function_def = VariableAddFunction(); + TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(), + status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TFE_Op* func = TFE_NewOp(ctx, "VariableAddFunction", status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TFE_OpAddInput(func, var_handle, status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; + TFE_Execute(func, &retvals[0], &num_retvals, status); + + if (inject_error) { + ASSERT_EQ(TF_INTERNAL, TF_GetCode(status)) << TF_Message(status); + } else { + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(1, num_retvals); + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteTensorHandle(retvals[0]); + float sum = 0; + ASSERT_EQ(sizeof(sum), TF_TensorByteSize(t)); + memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + ASSERT_EQ(sum, 4.0); + } + + TFE_DeleteOp(func); + TFE_DeleteTensorHandle(var_handle); + TFE_DeleteContext(ctx); + TF_DeleteStatus(status); + + // 
TODO(b/136478427): Figure out how to correctly shut the server down. + worker_server1.release(); + worker_server2.release(); +} + +TEST(CAPI, DistributedFunctionNoError) { + TestDistributedFunctionCancellation(false); +} + +TEST(CAPI, DistributedFunctionCancelledOnError) { + TestDistributedFunctionCancellation(true); +} + +void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) { + tensorflow::ServerDef server_def = GetServerDef(2); + + // This server def has the task index set to 0. + string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + + std::unique_ptr worker_server; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server) + .ok()); + ASSERT_TRUE(worker_server->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, + TFE_DEVICE_PLACEMENT_EXPLICIT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // Use large matrices so that RPCs don't return before we get a chance + // to call TFE_DeleteContext. + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100(ctx); + TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100(ctx); + const char remote_device_name[] = + "/job:localhost/replica:0/task:1/device:CPU:0"; + auto* h0_task1 = + TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + auto* h1_task1 = + TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1); + TFE_OpSetDevice(matmul, remote_device_name, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + + TFE_DeleteTensorHandle(h0_task0); + TFE_DeleteTensorHandle(h1_task0); + TFE_DeleteTensorHandle(h0_task1); + TFE_DeleteTensorHandle(h1_task1); + TFE_DeleteTensorHandle(retvals[0]); + + TFE_DeleteOp(matmul); + + TFE_DeleteContext(ctx); + + // TODO(b/136478427): Figure out how to correctly shut the server down. + worker_server.release(); +} + +TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPC) { + TestRemoteExecuteDeleteContextWithOutstandingRPC(false); +} + +TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPCAsync) { + TestRemoteExecuteDeleteContextWithOutstandingRPC(true); +} +} // namespace diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index 820650e315f..0d71b11531b 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/c/eager/tfe_op_internal.h" #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/common_runtime/composite_device.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h" #include "tensorflow/core/lib/monitoring/counter.h" @@ -638,3 +639,35 @@ TFE_TensorHandle* TFE_NewTensorHandleFromTensor(TFE_Context* ctx, TF_Tensor* t, return tensorflow::wrap( tensorflow::unwrap(ctx)->CreateLocalHandle(t->tensor)); } + +TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx, + TFE_TensorHandle** handles, + int* num_handles, + TF_Status* status) { + std::vector tensor_handles; + tensor_handles.reserve(*num_handles); + for (int i = 0; i < *num_handles; ++i) { + tensor_handles.push_back( + tensorflow::TensorHandleFromInterface(tensorflow::unwrap(handles[i]))); + } + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); + tensorflow::TensorHandle* handle = nullptr; + status->status = tensorflow::TensorHandle::CreatePackedHandle( + std::move(tensor_handles), context, &handle); + return tensorflow::wrap(handle); +} + +void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable, + TF_Status* status) { + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); + context->SetAllowSoftPlacement(enable); +} + +void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable, + TF_Status* status) { + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); + context->SetLogDevicePlacement(enable); +} diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index 33adce40da0..1b8efe61ee0 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -541,6 +541,26 @@ TF_CAPI_EXPORT extern TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_CAPI_EXPORT TFE_TensorHandle* TFE_NewTensorHandleFromTensor( TFE_Context* ctx, TF_Tensor* t, TF_Status* status); +// Create a packed TensorHandle with the given list of TensorHandles. +// If `handles` are on the same device, assign the same device to the packed +// handle; if `handles` are on different deivces, assign a CompositeDevice to +// it. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_CreatePackedTensorHandle( + TFE_Context* ctx, TFE_TensorHandle** handles, int* num_handles, + TF_Status* status); + +// Configure soft device placement policy for the eager executor. Note this +// policy is applied to any subsequent op executions. +TF_CAPI_EXPORT void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, + unsigned char enable, + TF_Status* status); + +// Configure device placement policy logging for the eager executor. Note this +// policy is applied to any subsequent op executions. +TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, + unsigned char enable, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api_remote_test.cc b/tensorflow/c/eager/c_api_remote_test.cc index 0f988b1456d..94c32cf3f30 100644 --- a/tensorflow/c/eager/c_api_remote_test.cc +++ b/tensorflow/c/eager/c_api_remote_test.cc @@ -19,37 +19,22 @@ limitations under the License. 
#include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h" +#include "tensorflow/core/common_runtime/function_optimization_registry.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/platform/casts.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/cluster.pb.h" +#include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" namespace { using ::tensorflow::string; -tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) { - tensorflow::ServerDef server_def; - server_def.set_protocol("grpc"); - server_def.set_job_name(job_name); - server_def.set_task_index(0); - tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster(); - tensorflow::JobDef* job_def = cluster_def->add_job(); - job_def->set_name(job_name); - for (int i = 0; i < num_tasks; i++) { - int port = tensorflow::testing::PickUnusedPortOrDie(); - job_def->mutable_tasks()->insert( - {i, tensorflow::strings::StrCat("localhost:", port)}); - } - return server_def; -} - -tensorflow::ServerDef GetServerDef(int num_tasks) { - return GetServerDef("localhost", num_tasks); -} - void TestRemoteExecute(bool async) { tensorflow::ServerDef server_def = GetServerDef(2); @@ -351,74 +336,4 @@ TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) { /*heavy_load_on_streaming_rpc=*/true); } -void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) { - tensorflow::ServerDef server_def = GetServerDef(2); - - // This server def has the task index set to 0. - string serialized = server_def.SerializeAsString(); - - server_def.set_task_index(1); - - std::unique_ptr worker_server; - ASSERT_TRUE(tensorflow::GrpcServer::Create( - server_def, tensorflow::Env::Default(), &worker_server) - .ok()); - ASSERT_TRUE(worker_server->Start().ok()); - - TF_Status* status = TF_NewStatus(); - TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_ContextOptionsSetAsync(opts, static_cast(async)); - TFE_ContextOptionsSetDevicePlacementPolicy(opts, - TFE_DEVICE_PLACEMENT_EXPLICIT); - TFE_Context* ctx = TFE_NewContext(opts, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteContextOptions(opts); - - TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - - // Use large matrices so that RPCs don't return before we get a chance - // to call TFE_DeleteContext. 
- TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100(ctx); - TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100(ctx); - const char remote_device_name[] = - "/job:localhost/replica:0/task:1/device:CPU:0"; - auto* h0_task1 = - TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - auto* h1_task1 = - TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - - TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1); - TFE_OpSetDevice(matmul, remote_device_name, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - - TFE_TensorHandle* retvals[1]; - int num_retvals = 1; - TFE_Execute(matmul, &retvals[0], &num_retvals, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteStatus(status); - - TFE_DeleteTensorHandle(h0_task0); - TFE_DeleteTensorHandle(h1_task0); - TFE_DeleteTensorHandle(h0_task1); - TFE_DeleteTensorHandle(h1_task1); - TFE_DeleteTensorHandle(retvals[0]); - - TFE_DeleteOp(matmul); - - TFE_DeleteContext(ctx); - - // TODO(b/136478427): Figure out how to correctly shut the server down. - worker_server.release(); -} - -TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPC) { - TestRemoteExecuteDeleteContextWithOutstandingRPC(false); -} - -TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPCAsync) { - TestRemoteExecuteDeleteContextWithOutstandingRPC(true); -} } // namespace diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 3160cb0e585..724176505ba 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -1132,51 +1132,6 @@ void BM_ExecuteFunction(int iters, int async) { } BENCHMARK(BM_ExecuteFunction)->Arg(0)->Arg(1); -TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value, - TF_Status* status) { - // Create the variable handle. - TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpSetAttrType(op, "dtype", TF_FLOAT); - TFE_OpSetAttrShape(op, "shape", {}, 0, status); - TFE_OpSetAttrString(op, "container", "", 0); - TFE_OpSetAttrString(op, "shared_name", "", 0); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_TensorHandle* var_handle = nullptr; - int num_retvals = 1; - TFE_Execute(op, &var_handle, &num_retvals, status); - TFE_DeleteOp(op); - if (TF_GetCode(status) != TF_OK) return nullptr; - CHECK_EQ(1, num_retvals); - - // Assign 'value' to it. - op = TFE_NewOp(ctx, "AssignVariableOp", status); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpSetAttrType(op, "dtype", TF_FLOAT); - TFE_OpAddInput(op, var_handle, status); - - // Convert 'value' to a TF_Tensor then a TFE_TensorHandle. - std::unique_ptr t( - TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor); - memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get())); - - std::unique_ptr - value_handle(TFE_NewTensorHandle(t.get(), status), - TFE_DeleteTensorHandle); - if (TF_GetCode(status) != TF_OK) return nullptr; - - TFE_OpAddInput(op, value_handle.get(), status); - if (TF_GetCode(status) != TF_OK) return nullptr; - - num_retvals = 0; - TFE_Execute(op, nullptr, &num_retvals, status); - TFE_DeleteOp(op); - if (TF_GetCode(status) != TF_OK) return nullptr; - CHECK_EQ(0, num_retvals); - - return var_handle; -} - TEST(CAPI, Variables) { // Variables use resource handles, so this is really a test for resource // tensor handling. 
@@ -1186,7 +1141,7 @@ TEST(CAPI, Variables) { ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_TensorHandle* var_handle = CreateVariable(ctx, 12.0, status); + TFE_TensorHandle* var_handle = TestVariable(ctx, 12.0); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status); @@ -1227,7 +1182,7 @@ void BM_ReadVariable(int iters) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_TensorHandle* var_handle = CreateVariable(ctx, 5.0, status); + TFE_TensorHandle* var_handle = TestVariable(ctx, 5.0); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status); @@ -1248,6 +1203,8 @@ void BM_ReadVariable(int iters) { CHECK_EQ(0, TFE_TensorHandleNumDims(h, status)); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); h = nullptr; + TFE_OpAddInput(op, var_handle, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); } tensorflow::testing::StopTiming(); TFE_DeleteOp(op); diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc index e67e17963b3..4b5ad8f50f7 100644 --- a/tensorflow/c/eager/c_api_test_util.cc +++ b/tensorflow/c/eager/c_api_test_util.cc @@ -18,7 +18,9 @@ limitations under the License. #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/strcat.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/cluster.pb.h" using tensorflow::string; @@ -133,6 +135,58 @@ TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx) { return th; } +TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value, + const tensorflow::string& device_name) { + TF_Status* status = TF_NewStatus(); + // Create the variable handle. + TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetAttrType(op, "dtype", TF_FLOAT); + TFE_OpSetAttrShape(op, "shape", {}, 0, status); + TFE_OpSetAttrString(op, "container", "", 0); + TFE_OpSetAttrString(op, "shared_name", "", 0); + if (!device_name.empty()) { + TFE_OpSetDevice(op, device_name.c_str(), status); + } + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_TensorHandle* var_handle = nullptr; + int num_retvals = 1; + TFE_Execute(op, &var_handle, &num_retvals, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_DeleteOp(op); + if (TF_GetCode(status) != TF_OK) return nullptr; + CHECK_EQ(1, num_retvals); + + // Assign 'value' to it. + op = TFE_NewOp(ctx, "AssignVariableOp", status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetAttrType(op, "dtype", TF_FLOAT); + TFE_OpAddInput(op, var_handle, status); + + // Convert 'value' to a TF_Tensor then a TFE_TensorHandle. 
+ std::unique_ptr t( + TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor); + memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get())); + + std::unique_ptr + value_handle(TFE_NewTensorHandle(t.get(), status), + TFE_DeleteTensorHandle); + if (TF_GetCode(status) != TF_OK) return nullptr; + + TFE_OpAddInput(op, value_handle.get(), status); + if (TF_GetCode(status) != TF_OK) return nullptr; + + num_retvals = 0; + TFE_Execute(op, nullptr, &num_retvals, status); + TFE_DeleteOp(op); + if (TF_GetCode(status) != TF_OK) return nullptr; + CHECK_EQ(0, num_retvals); + + TF_DeleteStatus(status); + + return var_handle; +} + TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) { TF_Status* status = TF_NewStatus(); @@ -244,3 +298,23 @@ bool GetDeviceName(TFE_Context* ctx, string* device_name, TF_DeleteDeviceList(devices); return false; } + +tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) { + tensorflow::ServerDef server_def; + server_def.set_protocol("grpc"); + server_def.set_job_name(job_name); + server_def.set_task_index(0); + tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster(); + tensorflow::JobDef* job_def = cluster_def->add_job(); + job_def->set_name(job_name); + for (int i = 0; i < num_tasks; i++) { + int port = tensorflow::testing::PickUnusedPortOrDie(); + job_def->mutable_tasks()->insert( + {i, tensorflow::strings::StrCat("localhost:", port)}); + } + return server_def; +} + +tensorflow::ServerDef GetServerDef(int num_tasks) { + return GetServerDef("localhost", num_tasks); +} diff --git a/tensorflow/c/eager/c_api_test_util.h b/tensorflow/c/eager/c_api_test_util.h index 11ae6d1181b..fcf62223f14 100644 --- a/tensorflow/c/eager/c_api_test_util.h +++ b/tensorflow/c/eager/c_api_test_util.h @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/tensorflow_server.pb.h" // Return a tensor handle containing a float scalar TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value); @@ -42,6 +43,11 @@ TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(TFE_Context* ctx); // Return a tensor handle containing a 3x2 matrix of floats TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx); +// Return a variable handle referring to a variable with the given initial value +// on the given device. +TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value, + const tensorflow::string& device_name = ""); + // Return an add op multiplying `a` by `b`. TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b); @@ -67,4 +73,11 @@ TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input, bool GetDeviceName(TFE_Context* ctx, tensorflow::string* device_name, const char* device_type); +// Create a ServerDef with the given `job_name` and add `num_tasks` tasks in it. +tensorflow::ServerDef GetServerDef(const tensorflow::string& job_name, + int num_tasks); + +// Create a ServerDef with job name "localhost" and add `num_tasks` tasks in it. +tensorflow::ServerDef GetServerDef(int num_tasks); + #endif // TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_ diff --git a/tensorflow/c/eager/c_api_unified_experimental.cc b/tensorflow/c/eager/c_api_unified_experimental.cc index 68afffb28b4..e5030a602b3 100644 --- a/tensorflow/c/eager/c_api_unified_experimental.cc +++ b/tensorflow/c/eager/c_api_unified_experimental.cc @@ -17,6 +17,8 @@ limitations under the License. 
#include +#include "absl/container/flat_hash_map.h" +#include "absl/strings/str_cat.h" #include "tensorflow/c/eager/c_api_unified_experimental_internal.h" #include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_status.h" @@ -26,6 +28,51 @@ using tensorflow::string; using tensorflow::internal::OutputList; using tensorflow::internal::unwrap; +namespace tensorflow { +namespace internal { +typedef absl::flat_hash_map FactoriesMap; + +static FactoriesMap& GetFactories() { + static FactoriesMap* factories = new FactoriesMap; + return *factories; +} + +static const char* default_factory = ""; + +void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) { + assert((!GetFactories().count(name)) || + (GetFactories()[name] == factory) && + "Duplicate tracing factory registration"); + GetFactories()[name] = factory; +} + +void SetDefaultTracingEngine(const char* name) { default_factory = name; } + +static ExecutionContext* CreateTracingExecutionContext(const char* fn_name, + TF_Status* s) { + auto entry = GetFactories().find(default_factory); + if (entry != GetFactories().end()) return entry->second(fn_name, s); + string msg = absl::StrCat( + "No tracing engine factory has been registered with the key '", + default_factory, "' (available: "); + // Ensure deterministic (sorted) order in the error message + std::set factories_sorted; + for (const auto& factory : GetFactories()) + factories_sorted.insert(factory.first); + const char* comma = ""; + for (const string& factory : factories_sorted) { + msg += comma + factory; + comma = ", "; + } + msg += ")"; + + TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str()); + return nullptr; +} + +} // end namespace internal +} // end namespace tensorflow + // ============================================================================= // Public C API entry points // @@ -36,6 +83,28 @@ using tensorflow::internal::unwrap; // // ============================================================================= +void TF_SetTracingImplementation(const char* name) { + tensorflow::internal::SetDefaultTracingEngine(name); +} + +// Creates a new TensorFlow function, it is an execution context attached to a +// given tracing context. 
+TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* s) { + return wrap(tensorflow::internal::CreateTracingExecutionContext(fn_name, s)); +} + +TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx, + TF_OutputList* outputs, TF_Status* s) { + auto* func = wrap(unwrap(ctx)->Finalize(unwrap(outputs), s)); + TF_DeleteExecutionContext(ctx); + return func; +} + +TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func, + TF_DataType dtype, TF_Status* s) { + return wrap(unwrap(func)->AddParameter(dtype, s)); +} + void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete unwrap(c); } TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) { @@ -58,6 +127,10 @@ int TF_OutputListNumOutputs(TF_OutputList* o) { TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) { return wrap(unwrap(o)->outputs[i]); } +void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor, + TF_Status* s) { + unwrap(o)->outputs.push_back(unwrap(tensor)); +} void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type, TF_Status* s) { diff --git a/tensorflow/c/eager/c_api_unified_experimental.h b/tensorflow/c/eager/c_api_unified_experimental.h index be8fc64c2e1..86c59a7f625 100644 --- a/tensorflow/c/eager/c_api_unified_experimental.h +++ b/tensorflow/c/eager/c_api_unified_experimental.h @@ -49,15 +49,26 @@ typedef struct TF_AbstractOp TF_AbstractOp; // setting functional attributes of other composite ops e.g. control flow. typedef struct TF_AbstractFunction TF_AbstractFunction; -// Creates a context for tracing the execution of operations into a function. -TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s); +// This allows the client to swap the implementation of the tracing engine. +// Any future call to TF_CreateFunction will use the implementation defined +// here. +void TF_SetTracingImplementation(const char* name); + +// Creates a new TensorFlow function. A Function is an execution context, and as +// such it can trace operations through TF_ExecuteOperation. After completing +// tracing, a function can be obtained by TF_FinalizeFunction. +TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* status); // Creates a context for eager execution of operations. TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions*, TF_Status* s); - void TF_DeleteExecutionContext(TF_ExecutionContext*); +// Add a new parameter to a TensorFlow Function. +// TODO(aminim): what about shape? +TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func, + TF_DataType dtype, TF_Status* s); + // Create an operation suitable to use with the provided context. The operation // requires its type (e.g. "AddV2") to be set independently. TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* ctx); @@ -77,19 +88,21 @@ void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name, void TF_DeleteAbstractTensor(TF_AbstractTensor*); // TF_OutputList holds the list of TF_AbstractTensor that results from executing -// an operation. -// It just lets us not specify the number of outputs of an operation -// beforehand. This forces a memory allocation in the runtime, which is bad, but -// it allows for generic code. -// TODO(aminim): the description above isn't clear with respect to -// TF_OutputListNumOutputs and the current eager implementation which requires -// the number of outputs to be set by the client. +// an operation, or provided to create a function. 
+// When executing an operation in an eager context, the expected number of +// outputs must be set beforehand with `TF_OutputListSetNumOutputs`. typedef struct TF_OutputList TF_OutputList; TF_OutputList* TF_NewOutputList(); void TF_DeleteOutputList(TF_OutputList* o); -void TF_OutputListSetNumOutputs(TF_OutputList* o, int, TF_Status*); +// Prepare tracing to the expected number of output for an operation. +void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs, TF_Status*); +// Return the number of outputs in the list. int TF_OutputListNumOutputs(TF_OutputList* o); +// Return the `i`th output in the list. TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i); +// Append a tensor at the end of the output list, growing its size by one. +void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor, + TF_Status*); // TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe // capture some inputs and then add a node in the graph. The output tensors are @@ -100,13 +113,12 @@ void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs, TF_ExecutionContext* ctx, TF_Status* s); // Creates a new TF_AbstractFunction from the current tracing states in the -// context. The returned TF_GraphToFunction must be deleted by the client. +// context. The provided `ctx` is consumed by this API call and deleted. +// The returned TF_AbstractFunction must be deleted by the client, // TODO(aminim): clarify the contract on the state of the context after this // call. -TF_AbstractFunction* TF_ExecutionContextToFunction( - const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs, - const TF_AbstractTensor* inputs, int num_outputs, - const TF_AbstractTensor* outputs, TF_Status* status); +TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx, + TF_OutputList*, TF_Status*); void TF_DeleteAbstractFunction(TF_AbstractFunction*); diff --git a/tensorflow/c/eager/c_api_unified_experimental_eager.cc b/tensorflow/c/eager/c_api_unified_experimental_eager.cc index 820c61445fb..cf8cf845834 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_eager.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_eager.cc @@ -123,6 +123,17 @@ class EagerContext : public ExecutionContext { } } + AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override { + TF_SetStatus(s, TF_INVALID_ARGUMENT, + "Can't add function parameter on an eager context."); + return nullptr; + } + AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override { + TF_SetStatus(s, TF_INVALID_ARGUMENT, + "Can't use finalize function on an eager context."); + return nullptr; + } + void RegisterFunction(AbstractFunction* afunc, TF_Status* s) override { auto* func = afunc->GetTfFunction(s); if (!func) { diff --git a/tensorflow/c/eager/c_api_unified_experimental_graph.cc b/tensorflow/c/eager/c_api_unified_experimental_graph.cc index 36f8353894b..dd5a95b3526 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_graph.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_graph.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/c/c_api.h" #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_unified_experimental.h" @@ -114,12 +115,14 @@ struct GraphFunction : public AbstractFunction { static constexpr AbstractFunctionKind kKind = kGraphFunc; }; -// GraphContext wraps a TF_Graph and manages the "execution" of operation, i.e. -// adding them to the graph. 
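Taken together, the new entry points above replace the old TF_NewGraphExecutionContext / TF_ExecutionContextToFunction pair. The following is a minimal sketch of the new call sequence, mirroring the updated unit tests further down; error checking is elided, and it assumes the default "graphdef" tracing engine (registered later in this change) is available:

#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"

// Traces y = x + x and returns it as a TF_AbstractFunction.
TF_AbstractFunction* TraceDoubleFunction() {
  TF_Status* s = TF_NewStatus();
  TF_SetTracingImplementation("graphdef");
  TF_ExecutionContext* fn_ctx = TF_CreateFunction("double", s);
  TF_AbstractTensor* x = TF_AddFunctionParameter(fn_ctx, TF_FLOAT, s);

  TF_AbstractOp* add = TF_NewAbstractOp(fn_ctx);
  TF_AbstractOpSetOpType(add, "Add", s);
  TF_AbstractOpSetOpName(add, "my_add", s);
  TF_AbstractTensor* inputs[2] = {x, x};
  TF_OutputList* outputs = TF_NewOutputList();
  TF_ExecuteOperation(add, 2, inputs, outputs, fn_ctx, s);
  TF_DeleteAbstractOp(add);

  // Finalize consumes `fn_ctx`; only the returned function must be deleted
  // (with TF_DeleteAbstractFunction) by the caller.
  TF_AbstractFunction* fn = TF_FinalizeFunction(fn_ctx, outputs, s);
  TF_DeleteOutputList(outputs);
  TF_DeleteStatus(s);
  return fn;
}

The resulting function can then be registered with an eager execution context and invoked by name, exactly as the TestBasicGraph and TestMultiOutputGraph tests below do.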
+// GraphContext wraps a TF_Graph modeling a single function and manages the +// "execution" of operation, i.e. adding them to the function. class GraphContext : public ExecutionContext { public: - GraphContext() - : ExecutionContext(kKind), graph_(new TF_Graph(), TF_DeleteGraph) {} + explicit GraphContext(const char* name) + : ExecutionContext(kKind), + graph_(new TF_Graph(), TF_DeleteGraph), + name_(name) {} AbstractOp* CreateOperation() override { // TODO(srbs): Should the lifetime of this op be tied to the context. @@ -136,6 +139,10 @@ class GraphContext : public ExecutionContext { return; } auto* tf_opdesc = graph_op->op_.release(); + if (tf_opdesc == nullptr) { + TF_SetStatus(s, TF_INVALID_ARGUMENT, "AbstractOp is incomplete."); + return; + } for (int i = 0; i < num_inputs; ++i) { auto* graph_tensor = dyncast(inputs[i]); if (!graph_tensor) { @@ -164,24 +171,38 @@ class GraphContext : public ExecutionContext { } } - TF_Function* ToFunction(const char* fn_name, int num_inputs, - const GraphTensor* inputs, int num_outputs, - const GraphTensor* outputs, TF_Status* status) const { - std::vector graph_inputs; - graph_inputs.resize(num_inputs); + AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override { + TF_OperationDescription* opdesc = + TF_NewOperation(graph_.get(), "Placeholder", + absl::StrCat("_input_", inputs_.size()).c_str()); + TF_SetAttrType(opdesc, "dtype", dtype); + auto* operation = TF_FinishOperation(opdesc, s); + if (!s->status.ok()) return nullptr; + + inputs_.push_back(TF_Output{operation, 0}); + return new GraphTensor(inputs_.back(), this); + } + + AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override { + std::unique_ptr func(new GraphFunction); std::vector graph_outputs; - graph_outputs.resize(num_outputs); - for (int i = 0; i < num_inputs; i++) { - graph_inputs[i] = inputs[i].output; - } - for (int i = 0; i < num_outputs; i++) { - graph_outputs[i] = outputs[i].output; + graph_outputs.reserve(outputs->outputs.size()); + for (AbstractTensor* abstract_output : outputs->outputs) { + GraphTensor* output = dyncast(abstract_output); + if (!output) { + TF_SetStatus(s, TF_UNIMPLEMENTED, + "Returning a non-graph tensor from a function has not " + "been implemented yet."); + return nullptr; + } + graph_outputs.push_back(output->output); } - return TF_GraphToFunction(graph_.get(), fn_name, 0, -1, nullptr, - graph_inputs.size(), graph_inputs.data(), - graph_outputs.size(), graph_outputs.data(), - nullptr, nullptr, fn_name, status); + func->func = TF_GraphToFunction( + graph_.get(), name_, 0, -1, nullptr, inputs_.size(), inputs_.data(), + graph_outputs.size(), graph_outputs.data(), nullptr, nullptr, name_, s); + if (TF_GetCode(s) != TF_OK) return nullptr; + return func.release(); } void RegisterFunction(AbstractFunction* func, TF_Status* s) override { @@ -195,54 +216,20 @@ class GraphContext : public ExecutionContext { private: std::unique_ptr graph_; + std::vector inputs_; + const char* name_; }; -// Helper that converts the graph currently held in the context into a function. 
-static AbstractFunction* ExecutionContextToFunction( - const ExecutionContext* fn_body, const char* fn_name, int num_inputs, - const AbstractTensor* inputs, int num_outputs, - const AbstractTensor* outputs, TF_Status* status) { - auto* graph_ctx = dyncast(fn_body); - if (graph_ctx == nullptr) { - TF_SetStatus(status, TF_INVALID_ARGUMENT, - "fn_body is not a TF_GraphContext."); - return nullptr; - } - auto* graph_inputs = dyncast(inputs); - if (!graph_inputs) { - TF_SetStatus(status, TF_INVALID_ARGUMENT, "inputs aren't GraphTensors."); - return nullptr; - } - auto* graph_outputs = dyncast(outputs); - if (!graph_outputs) { - TF_SetStatus(status, TF_INVALID_ARGUMENT, "outputs aren't GraphTensors."); - return nullptr; - } - GraphFunction* func = new GraphFunction; - func->func = graph_ctx->ToFunction(fn_name, num_inputs, graph_inputs, - num_outputs, graph_outputs, status); - return func; +static ExecutionContext* GraphTracingFactory(const char* name, TF_Status* s) { + return new GraphContext(name); } +// Register the tracing implemented in this file as the default tracing engine. +static bool register_tracing = [] { + RegisterTracingEngineFactory("graphdef", GraphTracingFactory); + SetDefaultTracingEngine("graphdef"); + return true; +}(); + } // namespace internal } // namespace tensorflow - -// ============================================================================= -// Public C API entry points -// These are only the entry points specific to the Graph API. -// ============================================================================= - -using tensorflow::internal::unwrap; - -TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s) { - return wrap(new tensorflow::internal::GraphContext()); -} - -TF_AbstractFunction* TF_ExecutionContextToFunction( - const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs, - const TF_AbstractTensor* inputs, int num_outputs, - const TF_AbstractTensor* outputs, TF_Status* status) { - return wrap(ExecutionContextToFunction(unwrap(fn_body), fn_name, num_inputs, - unwrap(inputs), num_outputs, - unwrap(outputs), status)); -} diff --git a/tensorflow/c/eager/c_api_unified_experimental_internal.h b/tensorflow/c/eager/c_api_unified_experimental_internal.h index ab085a20ff0..8fc696f0f2f 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_internal.h +++ b/tensorflow/c/eager/c_api_unified_experimental_internal.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/core/platform/casts.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace internal { @@ -57,7 +58,7 @@ T* dyncast(S source) { // GraphContext and vice-versa). class AbstractTensor { protected: - enum AbstractTensorKind { kGraphTensor, kEagerTensor, kMLIRTensor }; + enum AbstractTensorKind { kMlirTensor, kGraphTensor, kEagerTensor }; explicit AbstractTensor(AbstractTensorKind kind) : kind_(kind) {} public: @@ -100,7 +101,7 @@ class AbstractFunction { // on a given context, with the same or different input tensors. class AbstractOp { protected: - enum AbstractOpKind { kGraphOp, kEagerOp }; + enum AbstractOpKind { kMlirOp, kGraphOp, kEagerOp }; explicit AbstractOp(AbstractOpKind kind) : kind_(kind) {} public: @@ -128,7 +129,7 @@ class AbstractOp { // eager implementation or to a graph implementation. 
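The static `register_tracing` initializer above is how the graphdef backend plugs itself in; an alternative tracing engine can hook in the same way through RegisterTracingEngineFactory / SetDefaultTracingEngine, declared in the internal header below. A sketch under the assumption of a hypothetical engine name "my_engine"; a real factory would return a concrete ExecutionContext subclass rather than reporting an error:

#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/tf_status.h"

namespace tensorflow {
namespace internal {

// Placeholder factory: a real one returns a concrete ExecutionContext
// implementation (as GraphTracingFactory does above).
static ExecutionContext* MyTracingFactory(const char* fn_name, TF_Status* s) {
  TF_SetStatus(s, TF_UNIMPLEMENTED, "my_engine tracing is not implemented");
  return nullptr;
}

// Self-registration at static-initialization time, mirroring register_tracing.
static bool register_my_tracing = [] {
  RegisterTracingEngineFactory("my_engine", MyTracingFactory);
  // Optionally route TF_CreateFunction to this engine by default:
  // SetDefaultTracingEngine("my_engine");
  return true;
}();

}  // namespace internal
}  // namespace tensorflow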
struct ExecutionContext { protected: - enum ExecutionContextKind { kGraphContext, kEagerContext }; + enum ExecutionContextKind { kMlirContext, kGraphContext, kEagerContext }; explicit ExecutionContext(ExecutionContextKind kind) : k(kind) {} public: @@ -148,6 +149,17 @@ struct ExecutionContext { // Creates an empty AbstractOperation suitable to use with this context. virtual AbstractOp* CreateOperation() = 0; + // Add a function parameter and return the corresponding tensor. + // This is only valid with an ExecutionContext obtained from a TracingContext, + // it'll always error out with an eager context. + virtual AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) = 0; + + // Finalize this context and make a function out of it. The context is in a + // invalid state after this call and must be destroyed. + // This is only valid with an ExecutionContext obtained from a TracingContext, + // it'll always error out with an eager context. + virtual AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) = 0; + // Registers a functions with this context, after this the function is // available to be called/referenced by its name in this context. virtual void RegisterFunction(AbstractFunction* func, TF_Status* s) = 0; @@ -156,6 +168,11 @@ struct ExecutionContext { const ExecutionContextKind k; }; +typedef ExecutionContext* (*FactoryFunction)(const char* fn_name, TF_Status*); +void SetDefaultTracingEngine(const char* name); +void RegisterTracingEngineFactory(const ::tensorflow::string& name, + FactoryFunction factory); + // Create utilities to wrap/unwrap: this convert from the C opaque types to the // C++ implementation, and back. #define MAKE_WRAP_UNWRAP(C_TYPEDEF, CPP_CLASS) \ diff --git a/tensorflow/c/eager/c_api_unified_experimental_test.cc b/tensorflow/c/eager/c_api_unified_experimental_test.cc index 170b82333d8..24d170f2f99 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_test.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_test.cc @@ -29,7 +29,12 @@ using tensorflow::string; namespace tensorflow { namespace { -TEST(UnifedCAPI, TestBasicEager) { +class UnifiedCAPI : public ::testing::TestWithParam { + protected: + void SetUp() override { TF_SetTracingImplementation(GetParam()); } +}; + +TEST_P(UnifiedCAPI, TestBasicEager) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); TFE_ContextOptions* opts = TFE_NewContextOptions(); @@ -81,33 +86,18 @@ TEST(UnifedCAPI, TestBasicEager) { TF_DeleteExecutionContext(ctx); } -TEST(UnifedCAPI, TestBasicGraph) { +TEST_P(UnifiedCAPI, TestBasicGraph) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); - TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get()); + // Start a new function / execution context. + string fn_name = "double"; + TF_ExecutionContext* graph_ctx = + TF_CreateFunction(fn_name.c_str(), status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - // Add a placeholder to the graph. 
- auto* placeholder_op = TF_NewAbstractOp(graph_ctx); - TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get()); + auto* placeholder_t = + TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get()); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TF_AbstractOpSetAttrType(placeholder_op, "dtype", TF_FLOAT, status.get()); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - - // Build inputs and outputs. - TF_OutputList* placeholder_outputs = TF_NewOutputList(); - - // Execute. - TF_ExecuteOperation(placeholder_op, 0, nullptr, placeholder_outputs, - graph_ctx, status.get()); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - ASSERT_EQ(1, TF_OutputListNumOutputs(placeholder_outputs)); - TF_AbstractTensor* placeholder_t = TF_OutputListGet(placeholder_outputs, 0); - - // Delete placeholder op. - TF_DeleteAbstractOp(placeholder_op); // Build an abstract operation. auto* add_op = TF_NewAbstractOp(graph_ctx); @@ -123,16 +113,13 @@ TEST(UnifedCAPI, TestBasicGraph) { // Execute. TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TF_AbstractTensor* output_t = TF_OutputListGet(add_outputs, 0); // Clean up operation and inputs. TF_DeleteAbstractOp(add_op); - string fn_name = "double"; - TF_AbstractFunction* func = TF_ExecutionContextToFunction( - graph_ctx, fn_name.c_str(), 1, placeholder_t, 1, output_t, status.get()); - TF_DeleteAbstractTensor(placeholder_t); - TF_DeleteAbstractTensor(output_t); + TF_AbstractFunction* func = + TF_FinalizeFunction(graph_ctx, add_outputs, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Build eager context. TFE_ContextOptions* opts = TFE_NewContextOptions(); @@ -173,18 +160,161 @@ TEST(UnifedCAPI, TestBasicGraph) { ASSERT_EQ(*f_value, 4.0); TF_DeleteOutputList(add_outputs); - TF_DeleteOutputList(placeholder_outputs); TF_DeleteAbstractOp(fn_op); TF_DeleteAbstractTensor(input_t); TF_DeleteAbstractTensor(final_result); TF_DeleteTensor(f_t); TF_DeleteAbstractFunction(func); - TF_DeleteExecutionContext(graph_ctx); TF_DeleteExecutionContext(eager_execution_ctx); } -TEST(UnifedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) { +TEST_P(UnifiedCAPI, TestMultiOutputGraph) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_Status* s = status.get(); + + // Start a new function / execution context. + string fn_name = "two_adds"; + TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Create a first "Add" computing `arg0 + arg1`. + TF_AbstractTensor* add_output1; + { + // Build an abstract operation, inputs and output. 
+ auto* add_op = TF_NewAbstractOp(graph_ctx); + TF_AbstractOpSetOpType(add_op, "Add", s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_AbstractOpSetOpName(add_op, "my_add1", s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_AbstractTensor* inputs[2] = {arg0, arg1}; + TF_OutputList* add_outputs = TF_NewOutputList(); + // Trace the operation now (create a node in the graph). + TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_DeleteAbstractOp(add_op); + // Extract the resulting tensor. + add_output1 = TF_OutputListGet(add_outputs, 0); + TF_DeleteOutputList(add_outputs); + } + + // Same with a second "Add" computing `arg1 + arg1`. + TF_AbstractTensor* add_output2; + { + // Build an abstract operation, inputs and output. + auto* add_op = TF_NewAbstractOp(graph_ctx); + TF_AbstractOpSetOpType(add_op, "Add", s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_AbstractOpSetOpName(add_op, "my_add2", s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_AbstractTensor* inputs[2] = {arg1, arg1}; + TF_OutputList* add_outputs = TF_NewOutputList(); + // Trace the operation now (create a node in the graph). + TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_DeleteAbstractOp(add_op); + // Extract the resulting tensor. + add_output2 = TF_OutputListGet(add_outputs, 0); + TF_DeleteOutputList(add_outputs); + } + + // Finalize the function by providing the returned values. + TF_AbstractFunction* func; + { + // We want to return the output of both add operations, create a new list + // and populate it. + TF_OutputList* func_outputs = TF_NewOutputList(); + TF_OutputListPushBack(func_outputs, add_output1, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_OutputListPushBack(func_outputs, add_output2, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + func = TF_FinalizeFunction(graph_ctx, func_outputs, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_DeleteOutputList(func_outputs); + } + + /** + * We traced so far this function: + * + * def two_adds(a, b): + * my_add1 = a + b + * my_add2 = b + b + * return my_add1, my_add2 + * + * Now we will execute this function with an eager context: + * + * output1, output2 = two_adds(2.0, 3.0) + * + * and check that we got 5.0 and 6.0 as results. + */ + + // Build eager context. + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TF_ExecutionContext* eager_execution_ctx = + TF_NewEagerExecutionContext(opts, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TFE_DeleteContextOptions(opts); + + TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Build the abstract op to run the function. + TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx); + TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Build two abstract input tensors as function arguments. 
+ std::vector func_args; + { + TFE_Context* eager_ctx = + TF_ExecutionContextGetTFEContext(eager_execution_ctx); + TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f); + func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s)); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + input_eager = TestScalarTensorHandle(eager_ctx, 3.0f); + func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s)); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + } + + TF_OutputList* func_outputs = TF_NewOutputList(); + TF_OutputListSetNumOutputs(func_outputs, 2, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_ExecuteOperation(fn_op, func_args.size(), func_args.data(), func_outputs, + eager_execution_ctx, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_DeleteAbstractOp(fn_op); + for (TF_AbstractTensor* t : func_args) TF_DeleteAbstractTensor(t); + + ASSERT_EQ(2, TF_OutputListNumOutputs(func_outputs)); + float results[2]; + for (int idx = 0; idx < 2; ++idx) { + TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx); + TFE_TensorHandle* handle = TF_AbstractTensorGetEagerTensor(result, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Tensor* f_t = TFE_TensorHandleResolve(handle, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + results[idx] = *static_cast(TF_TensorData(f_t)); + TF_DeleteTensor(f_t); + } + ASSERT_EQ(results[0], 5.0); + ASSERT_EQ(results[1], 6.0); + + for (int idx = 0; idx < 2; ++idx) { + TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx); + TF_DeleteAbstractTensor(result); + } + TF_DeleteOutputList(func_outputs); + TF_DeleteExecutionContext(eager_execution_ctx); + TF_DeleteAbstractFunction(func); +} + +TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); TFE_ContextOptions* opts = TFE_NewContextOptions(); @@ -192,18 +322,15 @@ TEST(UnifedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) { ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); TFE_DeleteContextOptions(opts); - TF_AbstractFunction* func = TF_ExecutionContextToFunction( - ctx, nullptr, 0, nullptr, 0, nullptr, status.get()); + TF_AbstractFunction* func = TF_FinalizeFunction(ctx, nullptr, status.get()); ASSERT_EQ(nullptr, func); ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); - - TF_DeleteExecutionContext(ctx); } -TEST(UnifedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) { +TEST_P(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); - TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get()); + TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Add a placeholder to the graph. @@ -221,10 +348,10 @@ TEST(UnifedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) { TF_DeleteExecutionContext(graph_ctx); } -TEST(UnifedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) { +TEST_P(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); - TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get()); + TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Add a placeholder to the graph. 
@@ -242,7 +369,7 @@ TEST(UnifedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) { TF_DeleteExecutionContext(graph_ctx); } -TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) { +TEST_P(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) { // Build an Eager context. std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); @@ -272,7 +399,8 @@ TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) { ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Build a Graph context. - TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Execute eager op using graph context. @@ -288,10 +416,11 @@ TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) { TF_DeleteExecutionContext(graph_ctx); } -TEST(UnifedCAPI, TestExecutingGraphOpInEagerModeRaises) { +TEST_P(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); - TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Add a placeholder to the graph. @@ -348,5 +477,8 @@ TEST(UnifedCAPI, TestExecutingGraphOpInEagerModeRaises) { TF_DeleteExecutionContext(eager_execution_ctx); } +INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI, + ::testing::Values("graphdef", "mlir")); + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/eager/context_interface.h b/tensorflow/c/eager/context_interface.h index 9377bf0be12..2861fa43b66 100644 --- a/tensorflow/c/eager/context_interface.h +++ b/tensorflow/c/eager/context_interface.h @@ -59,6 +59,20 @@ class AbstractContextInterface { virtual AbstractTensorInterface* CreateTensor( DataType dtype, absl::Span dim_sizes) = 0; + typedef void (*MemoryReleaser)(void* data, size_t len, void* arg); + + // Create a tensor instance from the given data buffer and description. + // `memory_releaser` will be called on destruction, and it's responsible for + // cleaning up the underlying buffer. `convert_string` indicates whether it + // has to handle tstring conversion. Expected to be removed once tstring + // migration is done. + virtual AbstractTensorInterface* CreateTensor(DataType dtype, + const int64_t* dims, + int num_dims, void* data, + size_t len, bool convert_string, + MemoryReleaser memory_releaser, + void* memory_releaser_arg) = 0; + // Create a handle to wrap and manage a Tensor virtual AbstractTensorHandleInterface* CreateLocalHandle( AbstractTensorInterface* t) = 0; @@ -87,6 +101,9 @@ class AbstractContextInterface { // Destroy the step resource container for a training step. virtual void EndStep() = 0; + // Block until all pending nodes are finished. 
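The CreateTensor overload added above takes a caller-provided buffer together with a MemoryReleaser that frees it once the tensor is destroyed, avoiding a copy. A sketch of the pattern, assuming a valid AbstractContextInterface* is at hand; the helper names here are illustrative only and not part of this change:

#include <cstdint>

#include "tensorflow/c/eager/context_interface.h"
#include "tensorflow/core/framework/types.pb.h"

namespace tensorflow {

// Deletes the heap buffer backing the tensor once the tensor is destroyed.
static void ReleaseFloatBuffer(void* data, size_t /*len*/, void* /*arg*/) {
  delete[] static_cast<float*>(data);
}

// Wraps a caller-owned float buffer without copying it.
AbstractTensorInterface* WrapFloatBuffer(AbstractContextInterface* ctx) {
  float* data = new float[4]{1.f, 2.f, 3.f, 4.f};
  static const int64_t dims[1] = {4};
  return ctx->CreateTensor(DT_FLOAT, dims, /*num_dims=*/1, data,
                           /*len=*/4 * sizeof(float),
                           /*convert_string=*/false, &ReleaseFloatBuffer,
                           /*memory_releaser_arg=*/nullptr);
}

}  // namespace tensorflow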
+ virtual Status AsyncWait() = 0; + protected: virtual ~AbstractContextInterface() {} }; diff --git a/tensorflow/c/eager/parallel_device/BUILD b/tensorflow/c/eager/parallel_device/BUILD index f4dbcc6cead..6fce918aab1 100644 --- a/tensorflow/c/eager/parallel_device/BUILD +++ b/tensorflow/c/eager/parallel_device/BUILD @@ -27,6 +27,7 @@ cc_library( name = "parallel_device", srcs = [":sources"], hdrs = [":headers"], + visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/c:c_api", "//tensorflow/c/eager:c_api", @@ -38,11 +39,30 @@ cc_library( ], ) +cc_library( + name = "parallel_device_testlib", + testonly = 1, + srcs = ["parallel_device_testlib.cc"], + hdrs = ["parallel_device_testlib.h"], + deps = [ + ":parallel_device", + ":parallel_device_ops", + "//tensorflow/c:c_api", + "//tensorflow/c:c_api_experimental", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_experimental", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "parallel_device_test", srcs = ["parallel_device_test.cc"], deps = [ ":parallel_device", + ":parallel_device_ops", + ":parallel_device_testlib", "//tensorflow/c:c_api", "//tensorflow/c:c_api_experimental", "//tensorflow/c/eager:c_api", @@ -52,3 +72,40 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + +tf_cc_test( + name = "parallel_device_remote_test", + srcs = ["parallel_device_remote_test.cc"], + # TODO(b/136478427): Enable global heap checking when servers shut down + # cleanly. + args = ["--heap_check=local"], + deps = [ + ":parallel_device", + ":parallel_device_ops", + ":parallel_device_testlib", + "//tensorflow/c:c_api", + "//tensorflow/c:c_api_experimental", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_experimental", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", + ], +) + +# Note: ParallelDevice-specific ops are experimental and not currently linked in +# to TensorFlow by default, just used in a few tests. +filegroup( + name = "parallel_device_ops_srcs", + srcs = ["parallel_device_ops.cc"], + visibility = ["//tensorflow/python/distribute/parallel_device:__pkg__"], +) + +cc_library( + name = "parallel_device_ops", + srcs = [":parallel_device_ops_srcs"], + visibility = ["//tensorflow:internal"], + deps = ["//tensorflow/core:framework"], + alwayslink = 1, +) diff --git a/tensorflow/c/eager/parallel_device/parallel_device.cc b/tensorflow/c/eager/parallel_device/parallel_device.cc index e6846809fcf..75d188d0c45 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device.cc @@ -92,6 +92,10 @@ class ParallelDevice { TFE_TensorHandle* tensor, TF_Status* status) const; + // A parallel tensor with scalar integers numbering component devices. + std::unique_ptr DeviceIDs(TFE_Context* context, + TF_Status* status) const; + // Takes a description of a single operation being executed on the // ParallelDevice, and in turn runs one operation per component device with // its corresponding inputs from the input ParallelTensors (or @@ -208,6 +212,46 @@ std::unique_ptr ParallelDevice::CopyToParallelDevice( status); } +std::unique_ptr ParallelDevice::DeviceIDs( + TFE_Context* context, TF_Status* status) const { + // TODO(allenl): We could cache DeviceIDs (keyed by context). 
+ std::vector components; + components.reserve(underlying_devices_.size()); + for (int device_index = 0; device_index < underlying_devices_.size(); + ++device_index) { + int64_t* device_id = new int64_t; + *device_id = device_index; + std::unique_ptr tensor( + TF_NewTensor( + TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id, + sizeof(int64_t), + [](void* data, size_t, void* arg) { + delete reinterpret_cast(data); + }, + nullptr), + TF_DeleteTensor); + // TODO(allenl): Here and when executing regular operations, we could hold + // on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing + // device names repeatedly. + OpPtr const_op(TFE_NewOp(context, "Const", status)); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(), + status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64); + TFE_TensorHandle* device_handle; + int num_outputs = 1; + TFE_Execute(const_op.get(), &device_handle, &num_outputs, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + components.emplace_back(device_handle); + if (TF_GetCode(status) != TF_OK) return nullptr; + } + return ParallelTensor::FromTensorHandles(*this, std::move(components), + status); +} + absl::optional> ParallelDevice::Execute( TFE_Context* context, std::vector inputs, const char* operation_name, const TFE_OpAttrs* attributes, @@ -275,6 +319,11 @@ absl::optional> ParallelDevice::Execute( std::vector outputs; outputs.reserve(t->num_tensors()); for (int i = 0; i < t->num_tensors(); ++i) { + // TODO(b/157523095): Syncing the executor here shouldn't be + // necessary. Currently async+remote is missing cross-executor + // coordination. + TFE_ExecutorWaitForAllPendingNodes(executors_[i].get(), status); + if (TF_GetCode(status) != TF_OK) return result; TensorHandlePtr this_output( TFE_TensorHandleCopySharingTensor(t->tensor(i), status)); outputs.emplace_back(std::move(this_output)); @@ -282,6 +331,13 @@ absl::optional> ParallelDevice::Execute( } result.emplace(std::move(outputs)); return result; + } else if (operation_name == std::string("DeviceID")) { + std::vector result_content; + result_content.reserve(1); + result_content.push_back(DeviceIDs(context, status)); + if (TF_GetCode(status) != TF_OK) return result; + result.emplace(std::move(result_content)); + return result; } absl::optional>> maybe_parallel_results( diff --git a/tensorflow/c/eager/parallel_device/parallel_device_ops.cc b/tensorflow/c/eager/parallel_device/parallel_device_ops.cc new file mode 100644 index 00000000000..1decffca047 --- /dev/null +++ b/tensorflow/c/eager/parallel_device/parallel_device_ops.cc @@ -0,0 +1,26 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" + +// TODO(allenl): Figure out if we need this op, and if so whether we should move +// it to core TF. Right now the eager C API does some checking of op +// registrations before calling into custom devices, but we may be able to avoid +// that. +REGISTER_OP("DeviceID") + .Output("device_id: int64") + .SetIsStateful() + .SetShapeFn(tensorflow::shape_inference::ScalarShape); diff --git a/tensorflow/c/eager/parallel_device/parallel_device_remote_test.cc b/tensorflow/c/eager/parallel_device/parallel_device_remote_test.cc new file mode 100644 index 00000000000..32a4b440d25 --- /dev/null +++ b/tensorflow/c/eager/parallel_device/parallel_device_remote_test.cc @@ -0,0 +1,147 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_experimental.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/parallel_device/parallel_device.h" +#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" +#include "tensorflow/core/platform/test.h" + +tensorflow::ServerDef GetServerDef(const std::string& job_name, int num_tasks) { + tensorflow::ServerDef server_def; + server_def.set_protocol("grpc"); + server_def.set_job_name(job_name); + server_def.set_task_index(0); + tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster(); + tensorflow::JobDef* job_def = cluster_def->add_job(); + job_def->set_name(job_name); + for (int i = 0; i < num_tasks; i++) { + int port = tensorflow::testing::PickUnusedPortOrDie(); + job_def->mutable_tasks()->insert( + {i, tensorflow::strings::StrCat("localhost", ":", port)}); + } + return server_def; +} + +TEST(PARALLEL_DEVICE, TestRemoteBasic) { + std::unique_ptr opts( + TFE_NewContextOptions(), TFE_DeleteContextOptions); + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + std::unique_ptr context( + TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); + tensorflow::ServerDef server_def = GetServerDef("worker", 3); + + // This server def has the task index set to 0. 
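For reference, the DeviceID op registered above is intended to be placed on the parallel device itself, which intercepts it (see the `operation_name == "DeviceID"` branch added to ParallelDevice::Execute) and returns one scalar int64 per component device. A sketch of invoking it through the eager C API, assuming a parallel device has already been registered under `device_name` as in the tests:

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_status.h"

// Returns a parallel tensor whose i-th component is the scalar int64 `i`,
// placed on the i-th underlying device. `device_name` names the registered
// parallel device, e.g. "/job:localhost/replica:0/task:0/device:CUSTOM:0".
TFE_TensorHandle* ParallelDeviceIDs(TFE_Context* context,
                                    const char* device_name,
                                    TF_Status* status) {
  TFE_Op* op = TFE_NewOp(context, "DeviceID", status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetDevice(op, device_name, status);
  if (TF_GetCode(status) != TF_OK) {
    TFE_DeleteOp(op);
    return nullptr;
  }
  TFE_TensorHandle* result = nullptr;
  int num_retvals = 1;
  TFE_Execute(op, &result, &num_retvals, status);
  TFE_DeleteOp(op);
  // Per-device components can then be unpacked with TPUReplicatedOutput, as
  // ExtractPerDeviceValues does in the test library.
  return result;
}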
+ std::string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + std::unique_ptr worker_server1; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server1) + .ok()); + ASSERT_TRUE(worker_server1->Start().ok()); + + server_def.set_task_index(2); + std::unique_ptr worker_server2; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server2) + .ok()); + ASSERT_TRUE(worker_server2->Start().ok()); + + TFE_ContextSetServerDef(context.get(), 0, serialized.data(), + serialized.size(), status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + BasicTestsForTwoDevices(context.get(), + "/job:worker/replica:0/task:1/device:CPU:0", + "/job:worker/replica:0/task:2/device:CPU:0"); + + worker_server1.release(); + worker_server2.release(); +} + +TEST(PARALLEL_DEVICE, TestAsyncCopyOff) { + std::unique_ptr opts( + TFE_NewContextOptions(), TFE_DeleteContextOptions); + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + std::unique_ptr context( + TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); + tensorflow::ServerDef server_def = GetServerDef("worker", 3); + + // This server def has the task index set to 0. + std::string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + std::unique_ptr worker_server1; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server1) + .ok()); + ASSERT_TRUE(worker_server1->Start().ok()); + + server_def.set_task_index(2); + std::unique_ptr worker_server2; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server2) + .ok()); + ASSERT_TRUE(worker_server2->Start().ok()); + + TFE_ContextSetServerDef(context.get(), 0, serialized.data(), + serialized.size(), status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + const char* first_device = "/job:worker/replica:0/task:1/device:CPU:0"; + const char* second_device = "/job:worker/replica:0/task:2/device:CPU:0"; + const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; + std::array underlying_devices{first_device, second_device}; + RegisterParallelDevice(context.get(), device_name, underlying_devices, + status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + TensorHandlePtr value_one(FloatTensorHandle(3., status.get())); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TensorHandlePtr value_two(FloatTensorHandle(-2., status.get())); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + std::array in_components{value_one.get(), + value_two.get()}; + TensorHandlePtr combined_value = CreatePerDeviceValues( + context.get(), in_components, device_name, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + // Loop to make synchronization failures more deterministic + for (int i = 0; i < 100; ++i) { + TensorHandlePtr multiply_result( + Multiply(context.get(), combined_value.get(), combined_value.get(), + status.get())); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + std::array out_components; + ExtractPerDeviceValues(context.get(), multiply_result.get(), + &out_components, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + ExpectScalarEq(out_components[0].get(), 9.); + 
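ExpectScalarEq, used here and in the updated assertions throughout these tests, generalizes the old float-only AssertScalarFloatEq and is declared in the new parallel_device_testlib.h (not shown in this diff); call sites pass the element type explicitly, e.g. ExpectScalarEq<float>. A plausible sketch of it, not the actual implementation:

#include <memory>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/platform/test.h"

// Resolves `handle` and checks that its single element equals `expected_value`.
template <typename value_type>
void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> actual_value(
      TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  EXPECT_EQ(expected_value,
            *static_cast<value_type*>(TF_TensorData(actual_value.get())));
}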
ExpectScalarEq(out_components[1].get(), 4.); + } + + worker_server1.release(); + worker_server2.release(); +} diff --git a/tensorflow/c/eager/parallel_device/parallel_device_test.cc b/tensorflow/c/eager/parallel_device/parallel_device_test.cc index 9b0613b0391..d9784ac9fa6 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_test.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/c/c_api_experimental.h" #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h" #include "tensorflow/core/platform/test.h" // NOTE(allenl): These tests currently go through TFE_Execute and so are @@ -28,363 +29,6 @@ limitations under the License. // correspond fairly well to the implementation, but testing the C++ directly is // another option. -// Functor for making unique_ptr to TFE_TensorHandle slightly more -// ergonomic. Using decltype(TFE_DeleteTensorHandle) in the unique_ptr's second -// template argument requires passing a function pointer to -// TFE_DeleteTensorHandle when constructing the unique_ptr. -class TensorHandleDeleter { - public: - void operator()(TFE_TensorHandle* to_delete) { - TFE_DeleteTensorHandle(to_delete); - } -}; - -using TensorHandlePtr = std::unique_ptr; - -// A helper for performing common operations on variables. A much more -// restricted stand-in for tf.Variable in Python. -class Variable { - public: - // Construct a Variable from a resource-dtype TFE_TensorHandle and an - // indication of the dtype of the variable's value. - // - // Note that creating this resource-dtype handle can fail, so `Create` is a - // separate static method which returns a status. - Variable(TFE_TensorHandle* handle, TF_DataType type) - : handle_(handle), type_(type) {} - - // Helper for constructing a resource handle and wrapping it in a `Variable` - // object. - static Variable* Create(TFE_Context* context, TF_DataType type, - const int64_t* dims, const int num_dims, - const char* device, TF_Status* status); - // Dereferences the backing buffer for the variable. Note that since this can - // fail (it runs operations), it must be called explicitly and the resulting - // `status` checked. - void Destroy(TFE_Context* context, TF_Status* status); - - // Reads from the variable. - TensorHandlePtr Read(TFE_Context* context, TF_Status* status); - // Assigns a new value to the variable. - void Assign(TFE_Context* context, TFE_TensorHandle* value, TF_Status* status); - // Adds `value` to the existing value of the variable. - void AssignAdd(TFE_Context* context, TFE_TensorHandle* value, - TF_Status* status); - - private: - // Helper for running any single-argument assignment ops (Assign, AssignAdd, - // AssignSub, ...). - void GeneralAssignment(const char* op_name, TFE_Context* context, - TFE_TensorHandle* value, TF_Status* status); - - // The a handle for the resource-dtype tensor pointing to the variable's - // buffer. - TFE_TensorHandle* handle_; - // The dtype of the variable's buffer (input dtype for assignments, output - // dtype of read operations). 
- TF_DataType type_; -}; - -Variable* Variable::Create(TFE_Context* context, TF_DataType type, - const int64_t* dims, const int num_dims, - const char* device, TF_Status* status) { - std::unique_ptr op( - TFE_NewOp(context, "VarHandleOp", status), TFE_DeleteOp); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpSetAttrType(op.get(), "dtype", type); - TFE_OpSetAttrShape(op.get(), "shape", dims, num_dims, status); - TFE_OpSetAttrString(op.get(), "container", "", 0); - // Use the special GUID for no buffer sharing - // - // TODO(allenl): Should we provide a better API for this? AFAIK this is the - // only reasonable way to make variables with no aliasing using the eager C - // API. - std::string no_sharing = "cd2c89b7-88b7-44c8-ad83-06c2a9158347"; - TFE_OpSetAttrString(op.get(), "shared_name", no_sharing.c_str(), - no_sharing.length()); - TFE_OpSetDevice(op.get(), device, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_TensorHandle* var_handle = nullptr; - int num_retvals = 1; - TFE_Execute(op.get(), &var_handle, &num_retvals, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - return new Variable(var_handle, type); -} - -void Variable::Destroy(TFE_Context* context, TF_Status* status) { - // Free the backing buffer for the variable. - std::unique_ptr op( - TFE_NewOp(context, "DestroyResourceOp", status), &TFE_DeleteOp); - if (TF_GetCode(status) != TF_OK) return; - TFE_OpAddInput(op.get(), handle_, status); - if (TF_GetCode(status) != TF_OK) return; - const char* device = TFE_TensorHandleDeviceName(handle_, status); - if (TF_GetCode(status) != TF_OK) return; - TFE_OpSetDevice(op.get(), device, status); - if (TF_GetCode(status) != TF_OK) return; - int num_retvals = 0; - TFE_Execute(op.get(), nullptr, &num_retvals, status); - if (TF_GetCode(status) != TF_OK) return; - // Delete the variable handle itself. 
- TFE_DeleteTensorHandle(handle_); -} - -TensorHandlePtr Variable::Read(TFE_Context* context, TF_Status* status) { - std::unique_ptr op( - TFE_NewOp(context, "ReadVariableOp", status), &TFE_DeleteOp); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpAddInput(op.get(), handle_, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - const char* device = TFE_TensorHandleDeviceName(handle_, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpSetDevice(op.get(), device, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpSetAttrType(op.get(), "dtype", type_); - int num_retvals = 1; - TFE_TensorHandle* var_value = nullptr; - TFE_Execute(op.get(), &var_value, &num_retvals, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - return TensorHandlePtr(var_value); -} - -void Variable::GeneralAssignment(const char* op_name, TFE_Context* context, - TFE_TensorHandle* value, TF_Status* status) { - std::unique_ptr op( - TFE_NewOp(context, op_name, status), &TFE_DeleteOp); - if (TF_GetCode(status) != TF_OK) return; - TFE_OpSetAttrType(op.get(), "dtype", type_); - TFE_OpAddInput(op.get(), handle_, status); - if (TF_GetCode(status) != TF_OK) return; - TFE_OpAddInput(op.get(), value, status); - if (TF_GetCode(status) != TF_OK) return; - const char* device = TFE_TensorHandleDeviceName(handle_, status); - if (TF_GetCode(status) != TF_OK) return; - TFE_OpSetDevice(op.get(), device, status); - - int num_retvals = 0; - TFE_Execute(op.get(), nullptr, &num_retvals, status); - if (TF_GetCode(status) != TF_OK) return; -} - -void Variable::AssignAdd(TFE_Context* context, TFE_TensorHandle* value, - TF_Status* status) { - GeneralAssignment("AssignAddVariableOp", context, value, status); -} - -void Variable::Assign(TFE_Context* context, TFE_TensorHandle* value, - TF_Status* status) { - GeneralAssignment("AssignVariableOp", context, value, status); -} - -// Passed to `TF_NewTensor` to indicate how an array of floats should be -// deleted. -static void FloatDeallocator(void* data, size_t, void* arg) { - delete[] static_cast(data); -} - -// Creates a TFE_TensorHandle with value `v`. -TensorHandlePtr FloatTensorHandle(float v, TF_Status* status) { - const int num_bytes = sizeof(float); - float* values = new float[1]; - values[0] = v; - std::unique_ptr tensor( - TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes, &FloatDeallocator, - nullptr), - TF_DeleteTensor); - return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status)); -} - -// Creates a rank-one TFE_TensorHandle with value `v`. -TensorHandlePtr VectorFloatTensorHandle(const std::vector& v, - TF_Status* status) { - const int num_bytes = v.size() * sizeof(float); - float* values = new float[v.size()]; - memcpy(values, v.data(), num_bytes); - int64_t dims = v.size(); - std::unique_ptr tensor( - TF_NewTensor(TF_FLOAT, &dims, 1 /* num_dims */, values, num_bytes, - &FloatDeallocator, nullptr), - TF_DeleteTensor); - return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status)); -} - -// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle. 
-template -void ExtractPerDeviceValues( - TFE_Context* context, TFE_TensorHandle* input, - std::array* components, TF_Status* status) { - std::unique_ptr op( - TFE_NewOp(context, "TPUReplicatedOutput", status), TFE_DeleteOp); - if (TF_GetCode(status) != TF_OK) return; - TFE_OpSetAttrInt(op.get(), "num_replicas", num_replicas); - TFE_OpAddInput(op.get(), input, status); - if (TF_GetCode(status) != TF_OK) return; - const char* device = TFE_TensorHandleDeviceName(input, status); - if (TF_GetCode(status) != TF_OK) return; - TFE_OpSetDevice(op.get(), device, status); - if (TF_GetCode(status) != TF_OK) return; - - TFE_TensorHandle* result_handles[num_replicas]; - int num_retvals = num_replicas; - TFE_Execute(op.get(), result_handles, &num_retvals, status); - if (TF_GetCode(status) != TF_OK) return; - for (int i = 0; i < num_replicas; ++i) { - (*components)[i].reset(result_handles[i]); - } -} - -// Helper to pack `num_replicas` TFE_TensorHandles into one parallel handle. -template -TensorHandlePtr CreatePerDeviceValues( - TFE_Context* context, - const std::array& components, - const char* device, TF_Status* status) { - std::unique_ptr op( - TFE_NewOp(context, "TPUReplicatedInput", status), TFE_DeleteOp); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpSetAttrInt(op.get(), "N", num_replicas); - for (int i = 0; i < num_replicas; ++i) { - TFE_OpAddInput(op.get(), components[i], status); - if (TF_GetCode(status) != TF_OK) return nullptr; - } - TFE_OpSetDevice(op.get(), device, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - - TFE_TensorHandle* result_handle; - int num_retvals = 1; - TFE_Execute(op.get(), &result_handle, &num_retvals, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - return TensorHandlePtr(result_handle); -} - -TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first, - TFE_TensorHandle* second, TF_Status* status) { - std::unique_ptr op( - TFE_NewOp(context, "Mul", status), TFE_DeleteOp); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpAddInput(op.get(), first, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpAddInput(op.get(), second, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - const char* first_device = TFE_TensorHandleDeviceName(first, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpSetDevice(op.get(), first_device, status); - - TFE_TensorHandle* result_handle; - int num_retvals = 1; - TFE_Execute(op.get(), &result_handle, &num_retvals, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - return TensorHandlePtr(result_handle); -} - -// Assert that `handle` is equal to `expected_value`. 
-void AssertScalarFloatEq(TFE_TensorHandle* handle, float expected_value) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - std::unique_ptr value_zero( - TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - ASSERT_EQ(expected_value, - *static_cast(TF_TensorData(value_zero.get()))); -} - -template -void RegisterParallelDevice( - TFE_Context* context, const char* device_name, - const std::array& underlying_devices, - TF_Status* status) { - TFE_CustomDevice device; - void* device_info; - tensorflow::eager::AllocateParallelDevice( - device_name, underlying_devices.data(), underlying_devices.size(), - &device, &device_info); - TFE_RegisterCustomDevice(context, device, device_name, device_info, status); -} - -// Create and modify a variable placed on a parallel device which composes -// `first_device` and `second_device`. -void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device, - const char* second_device) { - // Register the custom device - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; - std::array underlying_devices{first_device, second_device}; - RegisterParallelDevice(context, device_name, underlying_devices, - status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - - // Create a variable handle (uninitialized to start) placed on the parallel - // device. - std::function variable_deleter = [&](Variable* to_delete) { - to_delete->Destroy(context, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - delete to_delete; - }; - std::unique_ptr variable( - Variable::Create(context, TF_FLOAT, /* Scalar */ {}, 0, device_name, - status.get()), - variable_deleter); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - - // Assign an initial value to the variable, implicitly mirroring it to each - // component device. - { - TensorHandlePtr initial_value = FloatTensorHandle(20., status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - - variable->Assign(context, initial_value.get(), status.get()); - } - - // Read from the variable and verify that we have a parallel tensor. - { - TensorHandlePtr read = variable->Read(context, status.get()); - std::array components; - ExtractPerDeviceValues(context, read.get(), &components, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - - AssertScalarFloatEq(components[0].get(), 20.); - AssertScalarFloatEq(components[1].get(), 20.); - - std::string first_device = - TFE_TensorHandleBackingDeviceName(components[0].get(), status.get()); - ASSERT_EQ(underlying_devices[0], first_device); - std::string second_device = - TFE_TensorHandleBackingDeviceName(components[1].get(), status.get()); - ASSERT_EQ(underlying_devices[1], second_device); - } - - // Add a parallel tensor with different values on each device to the variable. 
- { - TensorHandlePtr value_one(FloatTensorHandle(3., status.get())); - TensorHandlePtr value_two(FloatTensorHandle(-2., status.get())); - std::array components{value_one.get(), - value_two.get()}; - TensorHandlePtr combined_value = - CreatePerDeviceValues(context, components, device_name, status.get()); - variable->AssignAdd(context, combined_value.get(), status.get()); - } - - // Read the variable and verify that each component has the right modified - // value. - { - TensorHandlePtr read = variable->Read(context, status.get()); - std::array components; - ExtractPerDeviceValues(context, read.get(), &components, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - - AssertScalarFloatEq(components[0].get(), 23.); - AssertScalarFloatEq(components[1].get(), 18.); - - std::string first_device = - TFE_TensorHandleBackingDeviceName(components[0].get(), status.get()); - ASSERT_EQ(underlying_devices[0], first_device); - std::string second_device = - TFE_TensorHandleBackingDeviceName(components[1].get(), status.get()); - ASSERT_EQ(underlying_devices[1], second_device); - } -} - TEST(PARALLEL_DEVICE, TestBasicCPU) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); @@ -498,8 +142,8 @@ TEST(PARALLEL_DEVICE, TestExplicitCopies) { ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); // The value of the original tensor is replicated on each device. - AssertScalarFloatEq(components[0].get(), 3.); - AssertScalarFloatEq(components[1].get(), 3.); + ExpectScalarEq(components[0].get(), 3.); + ExpectScalarEq(components[1].get(), 3.); // Verify that the mirrors are placed on the component devices. std::string first_device = @@ -630,7 +274,7 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) { &second_components, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - AssertScalarFloatEq(second_components[1].get(), 9.); + ExpectScalarEq(second_components[1].get(), 9.); // Verify that the mirrors are placed on the component devices. std::string first_device = TFE_TensorHandleBackingDeviceName( @@ -644,8 +288,8 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) { std::array first_components; ExtractPerDeviceValues(context.get(), second_components[0].get(), &first_components, status.get()); - AssertScalarFloatEq(first_components[0].get(), 3.); - AssertScalarFloatEq(first_components[1].get(), 6.); + ExpectScalarEq(first_components[0].get(), 3.); + ExpectScalarEq(first_components[1].get(), 6.); first_device = TFE_TensorHandleBackingDeviceName(first_components[0].get(), status.get()); @@ -806,8 +450,8 @@ TEST(PARALLEL_DEVICE, TestCollective) { ExtractPerDeviceValues(context.get(), reduced.get(), &result_components, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - AssertScalarFloatEq(result_components[0].get(), 3.); - AssertScalarFloatEq(result_components[1].get(), 3.); + ExpectScalarEq(result_components[0].get(), 3.); + ExpectScalarEq(result_components[1].get(), 3.); } void RegisterCollectiveMulFunction(TFE_Context* context, @@ -909,8 +553,8 @@ TEST(PARALLEL_DEVICE, TestFunction) { ExtractPerDeviceValues(context.get(), reduced.get(), &result_components, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - AssertScalarFloatEq(result_components[0].get(), 7. * 9.); - AssertScalarFloatEq(result_components[1].get(), 7. * 9.); + ExpectScalarEq(result_components[0].get(), 7. * 9.); + ExpectScalarEq(result_components[1].get(), 7. 
* 9.); std::string first_device = TFE_TensorHandleBackingDeviceName( result_components[0].get(), status.get()); diff --git a/tensorflow/c/eager/parallel_device/parallel_device_testlib.cc b/tensorflow/c/eager/parallel_device/parallel_device_testlib.cc new file mode 100644 index 00000000000..fba47865c36 --- /dev/null +++ b/tensorflow/c/eager/parallel_device/parallel_device_testlib.cc @@ -0,0 +1,308 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h" + +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_experimental.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/core/platform/test.h" + +// NOTE(allenl): These tests currently go through TFE_Execute and so are +// integration testing rather than purely testing the parallel device. They +// correspond fairly well to the implementation, but testing the C++ directly is +// another option. + + +Variable* Variable::Create(TFE_Context* context, TF_DataType type, + const int64_t* dims, const int num_dims, + const char* device, TF_Status* status) { + std::unique_ptr op( + TFE_NewOp(context, "VarHandleOp", status), TFE_DeleteOp); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetAttrType(op.get(), "dtype", type); + TFE_OpSetAttrShape(op.get(), "shape", dims, num_dims, status); + TFE_OpSetAttrString(op.get(), "container", "", 0); + // Use the special GUID for no buffer sharing + // + // TODO(allenl): Should we provide a better API for this? AFAIK this is the + // only reasonable way to make variables with no aliasing using the eager C + // API. + std::string no_sharing = "cd2c89b7-88b7-44c8-ad83-06c2a9158347"; + TFE_OpSetAttrString(op.get(), "shared_name", no_sharing.c_str(), + no_sharing.length()); + TFE_OpSetDevice(op.get(), device, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_TensorHandle* var_handle = nullptr; + int num_retvals = 1; + TFE_Execute(op.get(), &var_handle, &num_retvals, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + return new Variable(var_handle, type); +} + +void Variable::Destroy(TFE_Context* context, TF_Status* status) { + // Free the backing buffer for the variable. + std::unique_ptr op( + TFE_NewOp(context, "DestroyResourceOp", status), &TFE_DeleteOp); + if (TF_GetCode(status) != TF_OK) return; + TFE_OpAddInput(op.get(), handle_, status); + if (TF_GetCode(status) != TF_OK) return; + const char* device = TFE_TensorHandleDeviceName(handle_, status); + if (TF_GetCode(status) != TF_OK) return; + TFE_OpSetDevice(op.get(), device, status); + if (TF_GetCode(status) != TF_OK) return; + int num_retvals = 0; + TFE_Execute(op.get(), nullptr, &num_retvals, status); + if (TF_GetCode(status) != TF_OK) return; + // Delete the variable handle itself. 
+  TFE_DeleteTensorHandle(handle_);
+}
+
+TensorHandlePtr Variable::Read(TFE_Context* context, TF_Status* status) {
+  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
+      TFE_NewOp(context, "ReadVariableOp", status), &TFE_DeleteOp);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  TFE_OpAddInput(op.get(), handle_, status);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  const char* device = TFE_TensorHandleDeviceName(handle_, status);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  TFE_OpSetDevice(op.get(), device, status);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  TFE_OpSetAttrType(op.get(), "dtype", type_);
+  int num_retvals = 1;
+  TFE_TensorHandle* var_value = nullptr;
+  TFE_Execute(op.get(), &var_value, &num_retvals, status);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  return TensorHandlePtr(var_value);
+}
+
+void Variable::GeneralAssignment(const char* op_name, TFE_Context* context,
+                                 TFE_TensorHandle* value, TF_Status* status) {
+  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
+      TFE_NewOp(context, op_name, status), &TFE_DeleteOp);
+  if (TF_GetCode(status) != TF_OK) return;
+  TFE_OpSetAttrType(op.get(), "dtype", type_);
+  TFE_OpAddInput(op.get(), handle_, status);
+  if (TF_GetCode(status) != TF_OK) return;
+  TFE_OpAddInput(op.get(), value, status);
+  if (TF_GetCode(status) != TF_OK) return;
+  const char* device = TFE_TensorHandleDeviceName(handle_, status);
+  if (TF_GetCode(status) != TF_OK) return;
+  TFE_OpSetDevice(op.get(), device, status);
+
+  int num_retvals = 0;
+  TFE_Execute(op.get(), nullptr, &num_retvals, status);
+  if (TF_GetCode(status) != TF_OK) return;
+}
+
+void Variable::AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
+                         TF_Status* status) {
+  GeneralAssignment("AssignAddVariableOp", context, value, status);
+}
+
+void Variable::Assign(TFE_Context* context, TFE_TensorHandle* value,
+                      TF_Status* status) {
+  GeneralAssignment("AssignVariableOp", context, value, status);
+}
+
+// Passed to `TF_NewTensor` to indicate how an array of floats should be
+// deleted.
+static void FloatDeallocator(void* data, size_t, void* arg) {
+  delete[] static_cast<float*>(data);
+}
+
+// Creates a TFE_TensorHandle with value `v`.
+TensorHandlePtr FloatTensorHandle(float v, TF_Status* status) {
+  const int num_bytes = sizeof(float);
+  float* values = new float[1];
+  values[0] = v;
+  std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
+      TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes, &FloatDeallocator,
+                   nullptr),
+      TF_DeleteTensor);
+  return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
+}
+
+// Creates a rank-one TFE_TensorHandle with value `v`.
+TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
+                                        TF_Status* status) {
+  const int num_bytes = v.size() * sizeof(float);
+  float* values = new float[v.size()];
+  memcpy(values, v.data(), num_bytes);
+  int64_t dims = v.size();
+  std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
+      TF_NewTensor(TF_FLOAT, &dims, 1 /* num_dims */, values, num_bytes,
+                   &FloatDeallocator, nullptr),
+      TF_DeleteTensor);
+  return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
+}
+
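The helpers above are exercised end to end by `BasicTestsForTwoDevices` later in this file; as a minimal sketch of how `Variable` and `FloatTensorHandle` compose (a sketch only, not part of the patch: the CPU device string and the pre-built `context` and `status` pointers are assumptions):

  // Sketch: create, update, read, and destroy a scalar float variable.
  Variable* variable = Variable::Create(
      context, TF_FLOAT, /* Scalar */ {}, 0,
      "/job:localhost/replica:0/task:0/device:CPU:0", status);
  TensorHandlePtr start_value = FloatTensorHandle(20., status);
  variable->Assign(context, start_value.get(), status);
  TensorHandlePtr update = FloatTensorHandle(3., status);
  variable->AssignAdd(context, update.get(), status);
  TensorHandlePtr read = variable->Read(context, status);  // Now holds 23.
  variable->Destroy(context, status);  // Runs an op; must be called explicitly.
  delete variable;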
+// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
+template <std::size_t num_replicas>
+void ExtractPerDeviceValues(
+    TFE_Context* context, TFE_TensorHandle* input,
+    std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status) {
+  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
+      TFE_NewOp(context, "TPUReplicatedOutput", status), TFE_DeleteOp);
+  if (TF_GetCode(status) != TF_OK) return;
+  TFE_OpSetAttrInt(op.get(), "num_replicas", num_replicas);
+  TFE_OpAddInput(op.get(), input, status);
+  if (TF_GetCode(status) != TF_OK) return;
+  const char* device = TFE_TensorHandleDeviceName(input, status);
+  if (TF_GetCode(status) != TF_OK) return;
+  TFE_OpSetDevice(op.get(), device, status);
+  if (TF_GetCode(status) != TF_OK) return;
+
+  TFE_TensorHandle* result_handles[num_replicas];
+  int num_retvals = num_replicas;
+  TFE_Execute(op.get(), result_handles, &num_retvals, status);
+  if (TF_GetCode(status) != TF_OK) return;
+  for (int i = 0; i < num_replicas; ++i) {
+    (*components)[i].reset(result_handles[i]);
+  }
+}
+
+TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
+                         TFE_TensorHandle* second, TF_Status* status) {
+  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
+      TFE_NewOp(context, "Mul", status), TFE_DeleteOp);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  TFE_OpAddInput(op.get(), first, status);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  TFE_OpAddInput(op.get(), second, status);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  const char* first_device = TFE_TensorHandleDeviceName(first, status);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  TFE_OpSetDevice(op.get(), first_device, status);
+
+  TFE_TensorHandle* result_handle;
+  int num_retvals = 1;
+  TFE_Execute(op.get(), &result_handle, &num_retvals, status);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  return TensorHandlePtr(result_handle);
+}
+
+// Create and modify a variable placed on a parallel device which composes
+// `first_device` and `second_device`.
+void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
+                             const char* second_device) {
+  // Register the custom device
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
+  std::array<const char*, 2> underlying_devices{first_device, second_device};
+  RegisterParallelDevice(context, device_name, underlying_devices,
+                         status.get());
+  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+  // Create a variable handle (uninitialized to start) placed on the parallel
+  // device.
+  std::function<void(Variable*)> variable_deleter = [&](Variable* to_delete) {
+    to_delete->Destroy(context, status.get());
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+    delete to_delete;
+  };
+  std::unique_ptr<Variable, decltype(variable_deleter)> variable(
+      Variable::Create(context, TF_FLOAT, /* Scalar */ {}, 0, device_name,
+                       status.get()),
+      variable_deleter);
+  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+  // Assign an initial value to the variable, implicitly mirroring it to each
+  // component device.
+  {
+    TensorHandlePtr initial_value = FloatTensorHandle(20., status.get());
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+    variable->Assign(context, initial_value.get(), status.get());
+  }
+
+  // Read from the variable and verify that we have a parallel tensor.
+  {
+    TensorHandlePtr read = variable->Read(context, status.get());
+    std::array<TensorHandlePtr, 2> components;
+    ExtractPerDeviceValues(context, read.get(), &components, status.get());
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+    ExpectScalarEq<float>(components[0].get(), 20.);
+    ExpectScalarEq<float>(components[1].get(), 20.);
+
+    std::string first_device =
+        TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
+    ASSERT_EQ(underlying_devices[0], first_device);
+    std::string second_device =
+        TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
+    ASSERT_EQ(underlying_devices[1], second_device);
+  }
+
+  // Add a parallel tensor with different values on each device to the
+  // variable.
+  {
+    TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
+    TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
+    std::array<TFE_TensorHandle*, 2> components{value_one.get(),
+                                                value_two.get()};
+    TensorHandlePtr combined_value =
+        CreatePerDeviceValues(context, components, device_name, status.get());
+    variable->AssignAdd(context, combined_value.get(), status.get());
+  }
+
+  // Read the variable and verify that each component has the right modified
+  // value.
+  {
+    TensorHandlePtr read = variable->Read(context, status.get());
+    std::array<TensorHandlePtr, 2> components;
+    ExtractPerDeviceValues(context, read.get(), &components, status.get());
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+    ExpectScalarEq<float>(components[0].get(), 23.);
+    ExpectScalarEq<float>(components[1].get(), 18.);
+
+    std::string first_device =
+        TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
+    ASSERT_EQ(underlying_devices[0], first_device);
+    std::string second_device =
+        TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
+    ASSERT_EQ(underlying_devices[1], second_device);
+  }
+  // Compute the device ID twice and verify the result
+  for (int i = 0; i < 2; ++i) {
+    std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
+        TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp);
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+    TFE_OpSetDevice(op.get(), device_name, status.get());
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+    TFE_TensorHandle* result_handle;
+    int num_retvals = 1;
+    TFE_Execute(op.get(), &result_handle, &num_retvals, status.get());
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+    std::array<TensorHandlePtr, 2> components;
+    ExtractPerDeviceValues(context, result_handle, &components, status.get());
+    TFE_DeleteTensorHandle(result_handle);
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+    ExpectScalarEq<int64_t>(components[0].get(), 0);
+    ExpectScalarEq<int64_t>(components[1].get(), 1);
+    std::string first_device =
+        TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
+    ASSERT_EQ(underlying_devices[0], first_device);
+    std::string second_device =
+        TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
+    ASSERT_EQ(underlying_devices[1], second_device);
+  }
+}
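With the shared test body factored out this way, a test target that links this library only needs to build an eager context and hand `BasicTestsForTwoDevices` two component device names. A minimal sketch of such a caller (not part of this patch; the device strings are assumptions, and a real test would also have to ensure two CPU devices are available):

TEST(PARALLEL_DEVICE_LIB, BasicTestsOnTwoCpus) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  // Assumed component device names; the parallel device itself is registered
  // inside BasicTestsForTwoDevices.
  BasicTestsForTwoDevices(context.get(),
                          "/job:localhost/replica:0/task:0/device:CPU:0",
                          "/job:localhost/replica:0/task:0/device:CPU:1");
}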
diff --git a/tensorflow/c/eager/parallel_device/parallel_device_testlib.h b/tensorflow/c/eager/parallel_device/parallel_device_testlib.h
new file mode 100644
index 00000000000..fdd21087949
--- /dev/null
+++ b/tensorflow/c/eager/parallel_device/parallel_device_testlib.h
@@ -0,0 +1,174 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_
+#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_
+
+#include "tensorflow/c/eager/parallel_device/parallel_device.h"
+
+#include <array>
+
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/c/c_api_experimental.h"
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_experimental.h"
+#include "tensorflow/core/platform/test.h"
+
+
+// Functor for making unique_ptr to TFE_TensorHandle slightly more
+// ergonomic. Using decltype(TFE_DeleteTensorHandle) in the unique_ptr's second
+// template argument requires passing a function pointer to
+// TFE_DeleteTensorHandle when constructing the unique_ptr.
+class TensorHandleDeleter {
+ public:
+  void operator()(TFE_TensorHandle* to_delete) {
+    TFE_DeleteTensorHandle(to_delete);
+  }
+};
+
+using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
+
+// A helper for performing common operations on variables. A much more
+// restricted stand-in for tf.Variable in Python.
+class Variable {
+ public:
+  // Construct a Variable from a resource-dtype TFE_TensorHandle and an
+  // indication of the dtype of the variable's value.
+  //
+  // Note that creating this resource-dtype handle can fail, so `Create` is a
+  // separate static method which returns a status.
+  Variable(TFE_TensorHandle* handle, TF_DataType type)
+      : handle_(handle), type_(type) {}
+
+  // Helper for constructing a resource handle and wrapping it in a `Variable`
+  // object.
+  static Variable* Create(TFE_Context* context, TF_DataType type,
+                          const int64_t* dims, const int num_dims,
+                          const char* device, TF_Status* status);
+  // Dereferences the backing buffer for the variable. Note that since this can
+  // fail (it runs operations), it must be called explicitly and the resulting
+  // `status` checked.
+  void Destroy(TFE_Context* context, TF_Status* status);
+
+  // Reads from the variable.
+  TensorHandlePtr Read(TFE_Context* context, TF_Status* status);
+  // Assigns a new value to the variable.
+  void Assign(TFE_Context* context, TFE_TensorHandle* value,
+              TF_Status* status);
+  // Adds `value` to the existing value of the variable.
+  void AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
+                 TF_Status* status);
+
+ private:
+  // Helper for running any single-argument assignment ops (Assign, AssignAdd,
+  // AssignSub, ...).
+  void GeneralAssignment(const char* op_name, TFE_Context* context,
+                         TFE_TensorHandle* value, TF_Status* status);
+
+  // A handle for the resource-dtype tensor pointing to the variable's
+  // buffer.
+  TFE_TensorHandle* handle_;
+  // The dtype of the variable's buffer (input dtype for assignments, output
+  // dtype of read operations).
+  TF_DataType type_;
+};
+
+// Creates a TFE_TensorHandle with value `v`.
+TensorHandlePtr FloatTensorHandle(float v, TF_Status* status);
+
+// Creates a rank-one TFE_TensorHandle with value `v`.
+TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
+                                        TF_Status* status);
+
+// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
+template <std::size_t num_replicas>
+void ExtractPerDeviceValues(
+    TFE_Context* context, TFE_TensorHandle* input,
+    std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status);
+
+// Helper to pack `num_replicas` TFE_TensorHandles into one parallel handle.
+template <std::size_t num_replicas>
+TensorHandlePtr CreatePerDeviceValues(
+    TFE_Context* context,
+    const std::array<TFE_TensorHandle*, num_replicas>& components,
+    const char* device, TF_Status* status);
+
+TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
+                         TFE_TensorHandle* second, TF_Status* status);
+
+// Expects that `handle` is a scalar equal to `expected_value`.
+template <typename value_type>
+void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value);
+
+template <std::size_t num_underlying_devices>
+void RegisterParallelDevice(
+    TFE_Context* context, const char* device_name,
+    const std::array<const char*, num_underlying_devices>& underlying_devices,
+    TF_Status* status);
+
+// Create and modify a variable placed on a parallel device which composes
+// `first_device` and `second_device`.
+void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
+                             const char* second_device);
+
+// Implementations of templated functions ******************************
+
+template <std::size_t num_replicas>
+TensorHandlePtr CreatePerDeviceValues(
+    TFE_Context* context,
+    const std::array<TFE_TensorHandle*, num_replicas>& components,
+    const char* device, TF_Status* status) {
+  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
+      TFE_NewOp(context, "TPUReplicatedInput", status), TFE_DeleteOp);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  TFE_OpSetAttrInt(op.get(), "N", num_replicas);
+  for (int i = 0; i < num_replicas; ++i) {
+    TFE_OpAddInput(op.get(), components[i], status);
+    if (TF_GetCode(status) != TF_OK) return nullptr;
+  }
+  TFE_OpSetDevice(op.get(), device, status);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+
+  TFE_TensorHandle* result_handle;
+  int num_retvals = 1;
+  TFE_Execute(op.get(), &result_handle, &num_retvals, status);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  return TensorHandlePtr(result_handle);
+}
+
+template <typename value_type>
+void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
+      TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
+  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+  EXPECT_EQ(expected_value,
+            *static_cast<value_type*>(TF_TensorData(value_zero.get())));
+}
+
+template <std::size_t num_underlying_devices>
+void RegisterParallelDevice(
+    TFE_Context* context, const char* device_name,
+    const std::array<const char*, num_underlying_devices>& underlying_devices,
+    TF_Status* status) {
+  TFE_CustomDevice device;
+  void* device_info;
+  tensorflow::eager::AllocateParallelDevice(
+      device_name, underlying_devices.data(), underlying_devices.size(),
+      &device, &device_info);
+  TFE_RegisterCustomDevice(context, device, device_name, device_info, status);
+}
+
+#endif  // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_
diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD b/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD
new file mode 100644
index 00000000000..34142fec5f7
--- /dev/null
+++ b/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD
@@ -0,0 +1,30 @@
+# Experimental gcs filesystem plugin.
+load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object")
+
+package(
+    licenses = ["notice"],  # Apache 2.0
+)
+
+# Filesystem implementation for GCS environments
+tf_cc_shared_object(
+    name = "gcs_filesystem",
+    framework_so = [],
+    linkstatic = False,
+    per_os_targets = 1,
+    visibility = ["//visibility:public"],
+    deps = [":gcs_filesystem_impl"],
+)
+
+# The real implementation of the filesystem.
+cc_library( + name = "gcs_filesystem_impl", + srcs = ["gcs_filesystem.cc"], + copts = select({ + "//conditions:default": [], + "//tensorflow:windows": get_win_copts(), + }), + deps = [ + "//tensorflow/c:tf_status", + "//tensorflow/c/experimental/filesystem:filesystem_interface", + ], +) diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc new file mode 100644 index 00000000000..ea9f59f1af3 --- /dev/null +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc @@ -0,0 +1,72 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include "tensorflow/c/experimental/filesystem/filesystem_interface.h" +#include "tensorflow/c/tf_status.h" + +// Implementation of a filesystem for GCS environments. +// This filesystem will support `gs://` URI schemes. + +static void* plugin_memory_allocate(size_t size) { return calloc(1, size); } +static void plugin_memory_free(void* ptr) { free(ptr); } + +// SECTION 1. Implementation for `TF_RandomAccessFile` +// ---------------------------------------------------------------------------- +namespace tf_random_access_file { + +// TODO(vnvo2409): Implement later + +} // namespace tf_random_access_file + +// SECTION 2. Implementation for `TF_WritableFile` +// ---------------------------------------------------------------------------- +namespace tf_writable_file { + +// TODO(vnvo2409): Implement later + +} // namespace tf_writable_file + +// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion` +// ---------------------------------------------------------------------------- +namespace tf_read_only_memory_region { + +// TODO(vnvo2409): Implement later + +} // namespace tf_read_only_memory_region + +// SECTION 4. 
Implementation for `TF_Filesystem`, the actual filesystem +// ---------------------------------------------------------------------------- +namespace tf_gcs_filesystem { + +// TODO(vnvo2409): Implement later + +} // namespace tf_gcs_filesystem + +static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops, + const char* uri) { + TF_SetFilesystemVersionMetadata(ops); + ops->scheme = strdup(uri); +} + +void TF_InitPlugin(TF_FilesystemPluginInfo* info) { + info->plugin_memory_allocate = plugin_memory_allocate; + info->plugin_memory_free = plugin_memory_free; + info->num_schemes = 1; + info->ops = static_cast( + plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0]))); + ProvideFilesystemSupportFor(&info->ops[0], "gs"); +} \ No newline at end of file diff --git a/tensorflow/c/experimental/network.cc b/tensorflow/c/experimental/network.cc index 94375cf9983..97e63ec6259 100644 --- a/tensorflow/c/experimental/network.cc +++ b/tensorflow/c/experimental/network.cc @@ -108,7 +108,7 @@ class CServerFactory : public ServerFactory { delete_function_(delete_function), rendezvous_builder_(rendezvous_builder) {} - Status NewServer(const ServerDef& server_def, + Status NewServer(const ServerDef& server_def, const Options& options, std::unique_ptr* out_server) override { TF_RETURN_IF_ERROR(CGrpcServer::Create( server_def, init_function_, start_function_, stop_function_, diff --git a/tensorflow/c/experimental/saved_model/internal/BUILD b/tensorflow/c/experimental/saved_model/internal/BUILD index 7a694f4f803..2ded784882b 100644 --- a/tensorflow/c/experimental/saved_model/internal/BUILD +++ b/tensorflow/c/experimental/saved_model/internal/BUILD @@ -31,9 +31,6 @@ cc_library( "//tensorflow/c/experimental/saved_model/public:concrete_function.h", ], copts = tf_copts(), - # TODO(bmzhao): Remove this as we refactor C API to granular targets, - # so that we can depend on c/eager/c_api_unified_experimental.h. 
- features = ["-layering_check"], visibility = [ "//tensorflow/c/experimental/saved_model/public:__pkg__", ], @@ -41,6 +38,8 @@ cc_library( ":concrete_function_type", ":function_metadata", ":function_metadata_type", + ":tensorhandle_list", + ":tensorhandle_list_type", "//tensorflow/c:c_api_macros", "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api_internal", @@ -156,10 +155,43 @@ cc_library( "saved_model_api_type.h", ], deps = [ + "//tensorflow/c:conversion_macros", "//tensorflow/c/experimental/saved_model/core:saved_model_api", ], ) +cc_library( + name = "tensorhandle_list", + srcs = [ + "tensorhandle_list.cc", + ], + hdrs = [ + "//tensorflow/c/experimental/saved_model/public:tensorhandle_list.h", + ], + copts = tf_copts(), + visibility = [ + "//tensorflow/c/experimental/saved_model/public:__pkg__", + ], + deps = [ + ":tensorhandle_list_type", + "//tensorflow/c:c_api_macros", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:tensor_handle_interface", + "//tensorflow/c/eager:tfe_tensorhandle_internal", + ], +) + +cc_library( + name = "tensorhandle_list_type", + hdrs = [ + "tensorhandle_list_type.h", + ], + deps = [ + "//tensorflow/c:conversion_macros", + "//tensorflow/c/eager:tensor_handle_interface", + ], +) + tf_cc_test( name = "saved_model_api_test", size = "small", diff --git a/tensorflow/c/experimental/saved_model/internal/concrete_function.cc b/tensorflow/c/experimental/saved_model/internal/concrete_function.cc index 4884f9e2e97..dd54416ddf9 100644 --- a/tensorflow/c/experimental/saved_model/internal/concrete_function.cc +++ b/tensorflow/c/experimental/saved_model/internal/concrete_function.cc @@ -15,12 +15,12 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/public/concrete_function.h" -#include "tensorflow/c/eager/c_api_unified_experimental.h" #include "tensorflow/c/eager/tfe_op_internal.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h" #include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h" #include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h" +#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h" extern "C" { @@ -29,10 +29,9 @@ TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(TF_ConcreteFunction* func) { &tensorflow::unwrap(func)->GetFunctionMetadata())); } -TF_OutputList* TF_ConcreteFunctionGetCaptures(TF_ConcreteFunction* func) { - // TODO(bmzhao): Refactor TF_OutputList struct definition into a separate - // internal header, and implement this function. 
- return nullptr; +const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures( + TF_ConcreteFunction* func) { + return tensorflow::wrap(&tensorflow::unwrap(func)->GetCaptures()); } TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func) { diff --git a/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc b/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc index 629610dbe29..9614e507646 100644 --- a/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc +++ b/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc @@ -41,7 +41,7 @@ TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx, if (!status->status.ok()) { return nullptr; } - return new TF_SavedModel{std::move(result)}; + return tensorflow::wrap(result.release()); } TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx, @@ -60,17 +60,19 @@ TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx, if (!status->status.ok()) { return nullptr; } - return new TF_SavedModel{std::move(result)}; + return tensorflow::wrap(result.release()); } -void TF_DeleteSavedModel(TF_SavedModel* model) { delete model; } +void TF_DeleteSavedModel(TF_SavedModel* model) { + delete tensorflow::unwrap(model); +} TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model, const char* function_path, TF_Status* status) { tensorflow::ConcreteFunction* result = nullptr; tensorflow::Status get_function_status = - model->saved_model->GetFunction(function_path, &result); + tensorflow::unwrap(model)->GetFunction(function_path, &result); status->status.Update(get_function_status); if (!get_function_status.ok()) { return nullptr; @@ -82,7 +84,8 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction( TF_SavedModel* model, const char* signature_def_key, TF_Status* status) { tensorflow::ConcreteFunction* result = nullptr; tensorflow::Status get_function_status = - model->saved_model->GetSignatureDefFunction(signature_def_key, &result); + tensorflow::unwrap(model)->GetSignatureDefFunction(signature_def_key, + &result); status->status.Update(get_function_status); if (!get_function_status.ok()) { return nullptr; @@ -91,7 +94,8 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction( } TF_ConcreteFunctionList* TF_ListSavedModelFunctions(TF_SavedModel* model) { - return new TF_ConcreteFunctionList{model->saved_model->ListFunctions()}; + return new TF_ConcreteFunctionList{ + tensorflow::unwrap(model)->ListFunctions()}; } } // end extern "C" diff --git a/tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h b/tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h index 9e2d1117463..380c3703426 100644 --- a/tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h +++ b/tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h @@ -18,13 +18,18 @@ limitations under the License. #include +#include "tensorflow/c/conversion_macros.h" #include "tensorflow/c/experimental/saved_model/core/saved_model_api.h" // Internal structures used by the SavedModel C API. These are likely to change // and should not be depended on. 
-struct TF_SavedModel { - std::unique_ptr saved_model; -}; +typedef struct TF_SavedModel TF_SavedModel; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::SavedModelAPI, TF_SavedModel) + +} // namespace tensorflow #endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_ diff --git a/tensorflow/c/experimental/saved_model/internal/tensorhandle_list.cc b/tensorflow/c/experimental/saved_model/internal/tensorhandle_list.cc new file mode 100644 index 00000000000..7d018658101 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/tensorhandle_list.cc @@ -0,0 +1,36 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h" + +#include + +#include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" +#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h" + +extern "C" { + +size_t TF_TensorHandleListSize(const TF_TensorHandleList* list) { + return tensorflow::unwrap(list)->size(); +} + +TFE_TensorHandle* TF_TensorHandleListGet(const TF_TensorHandleList* list, + int i) { + return tensorflow::wrap((*tensorflow::unwrap(list))[i]); +} + + +} // end extern "C" diff --git a/tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h b/tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h new file mode 100644 index 00000000000..8cbec2806a8 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h @@ -0,0 +1,37 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_ + +#include + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/eager/tensor_handle_interface.h" + +// Internal structures used by the SavedModel C API. These are likely to +// change and should not be depended on. 
+ +typedef struct TF_TensorHandleList TF_TensorHandleList; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS( + std::vector, + TF_TensorHandleList) + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_ diff --git a/tensorflow/c/experimental/saved_model/public/BUILD b/tensorflow/c/experimental/saved_model/public/BUILD index af65e05e7f6..0cfa0a2c005 100644 --- a/tensorflow/c/experimental/saved_model/public/BUILD +++ b/tensorflow/c/experimental/saved_model/public/BUILD @@ -24,6 +24,7 @@ exports_files( "concrete_function_list.h", "function_metadata.h", "saved_model_api.h", + "tensorhandle_list.h", ], visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"], ) @@ -39,6 +40,7 @@ cc_library( ":concrete_function_list", ":function_metadata", ":saved_model_api", + ":tensorhandle_list", ], ) @@ -61,3 +63,8 @@ alias( name = "saved_model_api", actual = "//tensorflow/c/experimental/saved_model/internal:saved_model_api", ) + +alias( + name = "tensorhandle_list", + actual = "//tensorflow/c/experimental/saved_model/internal:tensorhandle_list", +) diff --git a/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h b/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h index 30f533f140a..aae95a5477c 100644 --- a/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h +++ b/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h" #include "tensorflow/c/experimental/saved_model/public/function_metadata.h" #include "tensorflow/c/experimental/saved_model/public/saved_model_api.h" +#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h" // IWYU pragma: end_exports #endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_ diff --git a/tensorflow/c/experimental/saved_model/public/concrete_function.h b/tensorflow/c/experimental/saved_model/public/concrete_function.h index 351d8daed8e..2a87214270c 100644 --- a/tensorflow/c/experimental/saved_model/public/concrete_function.h +++ b/tensorflow/c/experimental/saved_model/public/concrete_function.h @@ -17,9 +17,9 @@ limitations under the License. #define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_ #include "tensorflow/c/c_api_macros.h" -#include "tensorflow/c/eager/c_api_internal.h" -#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/experimental/saved_model/public/function_metadata.h" +#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h" #ifdef __cplusplus extern "C" { @@ -36,7 +36,7 @@ TF_CAPI_EXPORT extern TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata( TF_ConcreteFunction* func); // Returns a list of TensorHandles implicitly captured by this function. -TF_CAPI_EXPORT extern TF_OutputList* TF_ConcreteFunctionGetCaptures( +TF_CAPI_EXPORT extern const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures( TF_ConcreteFunction* func); // Returns a TFE_Op suitable for executing this function. 
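Since `TF_ConcreteFunctionGetCaptures` now returns the new `const TF_TensorHandleList*` rather than a `TF_OutputList*`, callers walk a function's captures through the list accessors added in this change. A small sketch of that pattern (not part of the patch; the `func` value is assumed to come from `TF_GetSavedModelConcreteFunction`, and the list is only read here since this header exposes no deleter for it):

// Sketch: enumerate the tensor handles captured by a loaded ConcreteFunction.
const TF_TensorHandleList* captures = TF_ConcreteFunctionGetCaptures(func);
for (int i = 0; i < static_cast<int>(TF_TensorHandleListSize(captures)); ++i) {
  TFE_TensorHandle* capture = TF_TensorHandleListGet(captures, i);
  // Each element is a TFE_TensorHandle owned elsewhere; inspect it with the
  // usual eager C API, e.g. TFE_TensorHandleDataType(capture).
}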
diff --git a/tensorflow/c/experimental/saved_model/public/concrete_function_list.h b/tensorflow/c/experimental/saved_model/public/concrete_function_list.h index 7add847259c..e35546751f1 100644 --- a/tensorflow/c/experimental/saved_model/public/concrete_function_list.h +++ b/tensorflow/c/experimental/saved_model/public/concrete_function_list.h @@ -21,19 +21,27 @@ limitations under the License. #include "tensorflow/c/c_api_macros.h" #include "tensorflow/c/experimental/saved_model/public/concrete_function.h" +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + // An opaque type that is acts like a list of TF_ConcreteFunction pointers. typedef struct TF_ConcreteFunctionList TF_ConcreteFunctionList; // Returns the size of `list`. -TF_CAPI_EXPORT size_t -TF_ConcreteFunctionListSize(TF_ConcreteFunctionList* list); +TF_CAPI_EXPORT extern size_t TF_ConcreteFunctionListSize( + TF_ConcreteFunctionList* list); // Returns the `i`th TF_ConcreteFunction in the list. -TF_CAPI_EXPORT TF_ConcreteFunction* TF_ConcreteFunctionListGet( +TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_ConcreteFunctionListGet( TF_ConcreteFunctionList* list, int i); // Deletes `list`. -TF_CAPI_EXPORT void TF_DeleteConcreteFunctionList( +TF_CAPI_EXPORT extern void TF_DeleteConcreteFunctionList( TF_ConcreteFunctionList* list); +#ifdef __cplusplus +} // end extern "C" +#endif // __cplusplus + #endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_ diff --git a/tensorflow/c/experimental/saved_model/public/tensorhandle_list.h b/tensorflow/c/experimental/saved_model/public/tensorhandle_list.h new file mode 100644 index 00000000000..a1e88db3474 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/public/tensorhandle_list.h @@ -0,0 +1,43 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_ + +#include + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/eager/c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// An opaque type that is acts like a list of TF_ConcreteFunction pointers. +typedef struct TF_TensorHandleList TF_TensorHandleList; + +// Returns the size of `list`. +TF_CAPI_EXPORT extern size_t TF_TensorHandleListSize( + const TF_TensorHandleList* list); + +// Returns the `i`th TFE_TensorHandle in the list. 
+TF_CAPI_EXPORT extern TFE_TensorHandle* TF_TensorHandleListGet( + const TF_TensorHandleList* list, int i); + +#ifdef __cplusplus +} // end extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_ diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index e8cb40f153b..e1fad8e697a 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -178,7 +178,7 @@ cc_library_with_android_deps( name = "ops", srcs = ["framework/ops.cc"], hdrs = ["framework/ops.h"], - android_deps = ["//tensorflow/core:android_tensorflow_lib"], + android_deps = ["//tensorflow/core:portable_tensorflow_lib"], deps = [ "//tensorflow/core:core_cpu", "//tensorflow/core:framework", @@ -197,7 +197,7 @@ cc_library_with_android_deps( "framework/scope_internal.h", ], hdrs = ["framework/scope.h"], - android_deps = ["//tensorflow/core:android_tensorflow_lib"], + android_deps = ["//tensorflow/core:portable_tensorflow_lib"], common_deps = [ ":ops", ], @@ -237,7 +237,7 @@ cc_library_with_android_deps( name = "client_session", srcs = ["client/client_session.cc"], hdrs = ["client/client_session.h"], - android_deps = ["//tensorflow/core:android_tensorflow_lib"], + android_deps = ["//tensorflow/core:portable_tensorflow_lib"], common_deps = [ ":ops", ":scope", @@ -275,7 +275,7 @@ cc_library_with_android_deps( srcs = ["ops/const_op.cc"], hdrs = ["ops/const_op.h"], android_deps = [ - "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core:portable_tensorflow_lib", ], common_deps = [ ":ops", @@ -304,7 +304,7 @@ cc_library_with_android_deps( srcs = ["ops/while_loop.cc"], hdrs = ["ops/while_loop.h"], android_deps = [ - "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core:portable_tensorflow_lib", ], common_deps = [ ":cc_ops", diff --git a/tensorflow/cc/experimental/base/public/BUILD b/tensorflow/cc/experimental/base/public/BUILD index 4249d7918c8..045d4e6cd97 100644 --- a/tensorflow/cc/experimental/base/public/BUILD +++ b/tensorflow/cc/experimental/base/public/BUILD @@ -57,7 +57,22 @@ cc_library( "tensor.h", ], deps = [ + ":status", "//tensorflow/c:tf_datatype", "//tensorflow/c:tf_tensor", ], ) + +cc_library( + name = "tensorhandle", + hdrs = [ + "tensorhandle.h", + ], + deps = [ + ":runtime", + ":status", + ":tensor", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_experimental", + ], +) diff --git a/tensorflow/cc/experimental/base/public/runtime.h b/tensorflow/cc/experimental/base/public/runtime.h index 47fd8869647..711a38c233a 100644 --- a/tensorflow/cc/experimental/base/public/runtime.h +++ b/tensorflow/cc/experimental/base/public/runtime.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api_experimental.h" namespace tensorflow { +namespace experimental { namespace cc { // Runtime represents an opaque instance of a Tensorflow runtime, with its own @@ -40,6 +41,7 @@ class Runtime { private: friend class RuntimeBuilder; friend class SavedModelAPI; + friend class TensorHandle; // Wraps a TFE_Context. Takes ownership of ctx. 
explicit Runtime(TFE_Context* ctx) : ctx_(ctx) {} @@ -63,6 +65,7 @@ class Runtime { }; } // namespace cc +} // namespace experimental } // namespace tensorflow #endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_ diff --git a/tensorflow/cc/experimental/base/public/runtime_builder.h b/tensorflow/cc/experimental/base/public/runtime_builder.h index ed3c93ae135..737e06cb2c6 100644 --- a/tensorflow/cc/experimental/base/public/runtime_builder.h +++ b/tensorflow/cc/experimental/base/public/runtime_builder.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/cc/experimental/base/public/status.h" namespace tensorflow { +namespace experimental { namespace cc { // RuntimeBuilder is a builder used to construct a tensorflow::cc::Runtime. @@ -79,6 +80,7 @@ inline std::unique_ptr RuntimeBuilder::Build(Status* status) { } } // namespace cc +} // namespace experimental } // namespace tensorflow #endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_ diff --git a/tensorflow/cc/experimental/base/public/status.h b/tensorflow/cc/experimental/base/public/status.h index f91f2caccd8..98c8cf6ced2 100644 --- a/tensorflow/cc/experimental/base/public/status.h +++ b/tensorflow/cc/experimental/base/public/status.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/c/tf_status.h" namespace tensorflow { +namespace experimental { namespace cc { // Status is a wrapper around an error code and an optional error message. @@ -57,6 +58,7 @@ class Status { friend class RuntimeBuilder; friend class Runtime; friend class SavedModelAPI; + friend class TensorHandle; // Wraps a TF_Status*, and takes ownership of it. explicit Status(TF_Status* status) : status_(status) {} @@ -88,6 +90,7 @@ inline void Status::SetStatus(TF_Code code, const std::string& msg) { } } // namespace cc +} // namespace experimental } // namespace tensorflow #endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_ diff --git a/tensorflow/cc/experimental/base/public/tensor.h b/tensorflow/cc/experimental/base/public/tensor.h index 1afdbcad50c..fc447262ce1 100644 --- a/tensorflow/cc/experimental/base/public/tensor.h +++ b/tensorflow/cc/experimental/base/public/tensor.h @@ -19,30 +19,53 @@ limitations under the License. #include #include +#include #include +#include #include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_tensor.h" +#include "tensorflow/cc/experimental/base/public/status.h" namespace tensorflow { +namespace experimental { namespace cc { // Tensor represents an n-dimensional array of values. class Tensor { public: - // TODO(bmzhao): Add a factory function that constructs a Tensor from a char - // buffer, with an options struct (to specify the buffer's layout, device?, - // whether to create a TFRT or TF tensor, whether we should take ownership of - // the memory, etc). This requires extending TF_NewTensor with an options - // struct: - // https://github.com/tensorflow/tensorflow/blob/3c520614a3c056d56afdc79b59979b9b0087f8b9/tensorflow/c/tf_tensor.h#L77-L80 + using DeleterCallback = std::function; + + // Constructs a Tensor from user provided buffer. + // + // Params: + // dtype - The dtype of the tensor's data. + // shape - A shape vector, where each element corresponds to the size of + // the tensor's corresponding dimension. + // data - Pointer to a buffer of memory to construct a Tensor out of. + // len - The length (in bytes) of `data` + // deleter - A std::function to be called when the Tensor no longer needs the + // memory in `data`. 
This can be used to free `data`, or + // perhaps decrement a refcount associated with `data`, etc. + // status - Set to OK on success and an error on failure. + // Returns: + // If an error occurred, status->ok() will be false, and the returned + // Tensor must not be used. + // TODO(bmzhao): Add Runtime as an argument to this function so we can swap to + // a TFRT backed tensor. + // TODO(bmzhao): Add benchmarks on overhead for this function; we can + // consider using int64_t* + length rather than vector. + static Tensor FromBuffer(TF_DataType dtype, const std::vector& shape, + void* data, size_t len, DeleterCallback deleter, + Status* status); // TODO(bmzhao): In the case we construct a tensor from non-owned memory, // we should offer a way to deep copy the tensor into a new tensor, which // owns the underlying memory. This could be a .deepcopy()/clone() method. // TODO(bmzhao): In the future, we want to relax the non-copyability - // constraint. To do so, we can add a C API function that acts like CopyFrom: + // constraint. To do so, we can add a C API function that acts like + // CopyFrom: // https://github.com/tensorflow/tensorflow/blob/08931c1e3e9eb2e26230502d678408e66730826c/tensorflow/core/framework/tensor.h#L301-L311 // Tensor is movable, but not copyable @@ -85,6 +108,16 @@ class Tensor { // This object retains ownership of the pointer. TF_Tensor* GetTFTensor() const { return tensor_.get(); } + struct DeleterStruct { + std::function deleter; + }; + + static void DeleterFunction(void* memory, size_t len, void* deleter_struct) { + DeleterStruct* deleter = reinterpret_cast(deleter_struct); + deleter->deleter(memory, len); + delete deleter; + } + struct TFTensorDeleter { void operator()(TF_Tensor* p) const { TF_DeleteTensor(p); } }; @@ -111,7 +144,32 @@ inline size_t Tensor::num_bytes() const { return TF_TensorByteSize(tensor_.get()); } +inline Tensor Tensor::FromBuffer(TF_DataType dtype, + const std::vector& shape, void* data, + size_t len, DeleterCallback deleter, + Status* status) { + // Credit to apassos@ for this technique: + // Despite the fact that our API takes a std::function deleter, we are able + // to maintain ABI stability because: + // 1. Only a function pointer is sent across the C API (&DeleterFunction) + // 2. DeleterFunction is defined in the same build artifact that constructed + // the std::function (so there isn't confusion about std::function ABI). + // Note that 2. is satisifed by the fact that this is a header-only API, where + // the function implementations are inline. + + DeleterStruct* deleter_struct = new DeleterStruct{deleter}; + TF_Tensor* tensor = TF_NewTensor(dtype, shape.data(), shape.size(), data, len, + &DeleterFunction, deleter_struct); + if (tensor == nullptr) { + status->SetStatus(TF_INVALID_ARGUMENT, + "Failed to create tensor for input buffer"); + return Tensor(nullptr); + } + return Tensor(tensor); +} + } // namespace cc +} // namespace experimental } // namespace tensorflow #endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_ diff --git a/tensorflow/cc/experimental/base/public/tensorhandle.h b/tensorflow/cc/experimental/base/public/tensorhandle.h new file mode 100644 index 00000000000..99453ee7ea8 --- /dev/null +++ b/tensorflow/cc/experimental/base/public/tensorhandle.h @@ -0,0 +1,98 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_ +#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_ + +#include +#include + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/cc/experimental/base/public/runtime.h" +#include "tensorflow/cc/experimental/base/public/status.h" +#include "tensorflow/cc/experimental/base/public/tensor.h" + +namespace tensorflow { +namespace experimental { +namespace cc { + +// An opaque representation of a tensor computed/managed by the Tensorflow +// runtime (tensorflow:cc::Runtime). Unlike a tensor, a Tensorhandle may refer +// to tensors placed in memory of different devices or remote address spaces. +// Note that tensorflow::cc::Runtime MUST outlive all TensorHandles created +// from it. +class TensorHandle { + public: + // Unwraps a Tensor from the given TensorHandle. If an error occurred, + // status->ok() will be false, and the returned Tensor must not be used. + Tensor Resolve(Status* status); + + // Constructs a TensorHandle from a Tensor. If an error occurred, + // status->ok() will be false, and the returned TensorHandle must not be used. + static TensorHandle FromTensor(const Tensor& tensor, const Runtime& runtime, + Status* status); + + // TensorHandle is movable, and not copyable + TensorHandle(TensorHandle&&) = default; + TensorHandle& operator=(TensorHandle&&) = default; + + private: + // Wraps a TFE_TensorHandle. Takes ownership of handle. + explicit TensorHandle(TFE_TensorHandle* handle) : handle_(handle) {} + + // TensorHandle is not copyable + TensorHandle(const TensorHandle&) = delete; + TensorHandle& operator=(const TensorHandle&) = delete; + + // Returns the underlying TFE_TensorHandle that this object wraps. + // This object retains ownership of the pointer. + TFE_TensorHandle* GetTFETensorHandle() const { return handle_.get(); } + + // Deletes the currently wrapped TFE_TensorHandle, and swaps it with handle, + // and takes ownership of handle. 
+ void Reset(TFE_TensorHandle* handle) { handle_.reset(handle); } + + struct TFETensorHandleDeleter { + void operator()(TFE_TensorHandle* p) const { TFE_DeleteTensorHandle(p); } + }; + std::unique_ptr handle_; +}; + +inline Tensor TensorHandle::Resolve(Status* status) { + TF_Tensor* tensor = + TFE_TensorHandleResolve(handle_.get(), status->GetTFStatus()); + if (!status->ok()) { + return Tensor(nullptr); + } + return Tensor(tensor); +} + +inline TensorHandle TensorHandle::FromTensor(const Tensor& tensor, + const Runtime& runtime, + Status* status) { + TFE_TensorHandle* tensor_handle = TFE_NewTensorHandleFromTensor( + runtime.GetTFEContext(), tensor.GetTFTensor(), status->GetTFStatus()); + if (!status->ok()) { + return TensorHandle(nullptr); + } + return TensorHandle(tensor_handle); +} + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_ diff --git a/tensorflow/cc/experimental/base/tests/BUILD b/tensorflow/cc/experimental/base/tests/BUILD new file mode 100644 index 00000000000..f449d618f72 --- /dev/null +++ b/tensorflow/cc/experimental/base/tests/BUILD @@ -0,0 +1,50 @@ +# Tests for the C++ header-only base types. +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +package( + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "tensor_types_test_util", + testonly = True, + hdrs = ["tensor_types_test_util.h"], + deps = [ + "//tensorflow/c:tf_datatype", + ], +) + +tf_cc_test( + name = "tensor_test", + srcs = [ + "tensor_test.cc", + ], + deps = [ + ":tensor_types_test_util", + "//tensorflow/c:tf_datatype", + "//tensorflow/cc/experimental/base/public:status", + "//tensorflow/cc/experimental/base/public:tensor", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "tensorhandle_test", + srcs = [ + "tensorhandle_test.cc", + ], + deps = [ + ":tensor_types_test_util", + "//tensorflow/c:tf_datatype", + "//tensorflow/cc/experimental/base/public:runtime", + "//tensorflow/cc/experimental/base/public:runtime_builder", + "//tensorflow/cc/experimental/base/public:status", + "//tensorflow/cc/experimental/base/public:tensor", + "//tensorflow/cc/experimental/base/public:tensorhandle", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/cc/experimental/base/tests/tensor_test.cc b/tensorflow/cc/experimental/base/tests/tensor_test.cc new file mode 100644 index 00000000000..33f9ab637e8 --- /dev/null +++ b/tensorflow/cc/experimental/base/tests/tensor_test.cc @@ -0,0 +1,163 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/cc/experimental/base/public/tensor.h" + +#include +#include + +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/test.h" + +namespace { + +using tensorflow::experimental::cc::Status; +using tensorflow::experimental::cc::Tensor; + +using SimpleTypes = ::testing::Types< + tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type, + tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type, + tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>; + +template +class ConstructScalarTensorTest : public ::testing::Test {}; +TYPED_TEST_SUITE(ConstructScalarTensorTest, SimpleTypes); + +// This test constructs a scalar tensor for each of the types in "SimpleTypes", +// and verifies the expected dimensions, dtype, value, number of bytes, and +// number of elements. +TYPED_TEST(ConstructScalarTensorTest, ValidTensorAttributesAfterConstruction) { + Status status; + TF_DataType dtype = TypeParam::kDType; + typename TypeParam::type value = 42; + Tensor tensor = Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{}, + /*data=*/&value, + /*len=*/sizeof(value), + /*deleter=*/[](void*, size_t) {}, &status); + ASSERT_TRUE(status.ok()) << status.message(); + + EXPECT_EQ(tensor.dims(), 0); + EXPECT_EQ(tensor.dtype(), dtype); + EXPECT_EQ(*reinterpret_cast(tensor.data()), 42); + EXPECT_EQ(tensor.num_bytes(), sizeof(typename TypeParam::type)); + EXPECT_EQ(tensor.num_elements(), 1); +} + +template +class Construct1DTensorTest : public ::testing::Test {}; +TYPED_TEST_SUITE(Construct1DTensorTest, SimpleTypes); + +// This test constructs a 1D tensor for each of the types in "SimpleTypes", +// and verifies the expected dimensions, dtype, value, number of bytes, and +// number of elements. +TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) { + Status status; + TF_DataType dtype = TypeParam::kDType; + // This is our 1D tensor of varying dtype. + std::vector value = {42, 100, 0, 1, 4, 29}; + // Shape is Rank 1 vector. + std::vector shape; + shape.push_back(value.size()); + + Tensor tensor = Tensor::FromBuffer( + /*dtype=*/dtype, /*shape=*/shape, + /*data=*/value.data(), + /*len=*/value.size() * sizeof(typename TypeParam::type), + /*deleter=*/[](void*, size_t) {}, &status); + ASSERT_TRUE(status.ok()) << status.message(); + + EXPECT_EQ(tensor.dims(), 1); + EXPECT_EQ(tensor.dtype(), dtype); + tensorflow::gtl::ArraySlice tensor_view( + reinterpret_cast(tensor.data()), value.size()); + EXPECT_EQ(tensor_view[0], 42); + EXPECT_EQ(tensor_view[1], 100); + EXPECT_EQ(tensor_view[2], 0); + EXPECT_EQ(tensor_view[3], 1); + EXPECT_EQ(tensor_view[4], 4); + EXPECT_EQ(tensor_view[5], 29); + + EXPECT_EQ(tensor.num_bytes(), + value.size() * sizeof(typename TypeParam::type)); + EXPECT_EQ(tensor.num_elements(), value.size()); +} + +template +class Construct2DTensorTest : public ::testing::Test {}; +TYPED_TEST_SUITE(Construct2DTensorTest, SimpleTypes); + +// This test constructs a 2D tensor for each of the types in "SimpleTypes", +// and verifies the expected dimensions, dtype, value, number of bytes, and +// number of elements. +TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) { + Status status; + TF_DataType dtype = TypeParam::kDType; + // This is our 1D tensor of varying dtype. 
+ std::vector value = {42, 100, 0, 1, 4, 29}; + // Shape is Rank 2 vector with shape 2 x 3. + std::vector shape({2, 3}); + + Tensor tensor = Tensor::FromBuffer( + /*dtype=*/dtype, /*shape=*/shape, + /*data=*/value.data(), + /*len=*/value.size() * sizeof(typename TypeParam::type), + /*deleter=*/[](void*, size_t) {}, &status); + + ASSERT_TRUE(status.ok()) << status.message(); + + EXPECT_EQ(tensor.dims(), 2); + EXPECT_EQ(tensor.dtype(), dtype); + tensorflow::gtl::ArraySlice tensor_view( + reinterpret_cast(tensor.data()), value.size()); + EXPECT_EQ(tensor_view[0], 42); + EXPECT_EQ(tensor_view[1], 100); + EXPECT_EQ(tensor_view[2], 0); + EXPECT_EQ(tensor_view[3], 1); + EXPECT_EQ(tensor_view[4], 4); + EXPECT_EQ(tensor_view[5], 29); + + EXPECT_EQ(tensor.num_bytes(), + value.size() * sizeof(typename TypeParam::type)); + EXPECT_EQ(tensor.num_elements(), value.size()); +} + +TEST(CPPTensorAPI, ConstructTensorFromBuffer) { + bool done = false; + Status status; + std::vector data_vector({12, 14, 20, 18, 39, 42, 100}); + { + // data_vector is a rank 1 tensor. + std::vector shape; + shape.push_back(data_vector.size()); + + Tensor::DeleterCallback callback = [&done](void* data, size_t len) { + done = true; + }; + + Tensor tensor = + Tensor::FromBuffer(/*dtype=*/TF_INT32, /*shape=*/shape, + /*data=*/data_vector.data(), + /*len=*/data_vector.size() * sizeof(int32_t), + /*deleter=*/callback, &status); + ASSERT_TRUE(status.ok()) << status.message(); + } + // At this point, tensor has been destroyed, and the deleter callback should + // have run. + EXPECT_TRUE(done); +} + +} // namespace diff --git a/tensorflow/cc/experimental/base/tests/tensor_types_test_util.h b/tensorflow/cc/experimental/base/tests/tensor_types_test_util.h new file mode 100644 index 00000000000..af9cad7529b --- /dev/null +++ b/tensorflow/cc/experimental/base/tests/tensor_types_test_util.h @@ -0,0 +1,76 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_ +#define TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_ + +#include + +#include "tensorflow/c/tf_datatype.h" + +namespace tensorflow { + +// Each of the following struct types have two members: a kDType that +// corresponds to a TF_Datatype enum value, and a typedef "type" +// of its corresponding C++ type. 
These types allow us to write Dtype-agnostic +// tests via GoogleTest's TypedTests: +// https://github.com/google/googletest/blob/e589a337170554c48bc658cc857cf15080c9eacc/googletest/docs/advanced.md#typed-tests +struct FloatType { + using type = float; + static constexpr TF_DataType kDType = TF_FLOAT; +}; + +struct DoubleType { + using type = double; + static constexpr TF_DataType kDType = TF_DOUBLE; +}; + +struct Int32Type { + using type = int32_t; + static constexpr TF_DataType kDType = TF_INT32; +}; + +struct UINT8Type { + using type = uint8_t; + static constexpr TF_DataType kDType = TF_UINT8; +}; + +struct INT8Type { + using type = int8_t; + static constexpr TF_DataType kDType = TF_INT8; +}; + +struct INT64Type { + using type = int64_t; + static constexpr TF_DataType kDType = TF_INT64; +}; + +struct UINT16Type { + using type = uint16_t; + static constexpr TF_DataType kDType = TF_UINT16; +}; + +struct UINT32Type { + using type = uint32_t; + static constexpr TF_DataType kDType = TF_UINT32; +}; + +struct UINT64Type { + using type = uint64_t; + static constexpr TF_DataType kDType = TF_UINT64; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_ diff --git a/tensorflow/cc/experimental/base/tests/tensorhandle_test.cc b/tensorflow/cc/experimental/base/tests/tensorhandle_test.cc new file mode 100644 index 00000000000..cfeaba4e392 --- /dev/null +++ b/tensorflow/cc/experimental/base/tests/tensorhandle_test.cc @@ -0,0 +1,184 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/experimental/base/public/tensorhandle.h" + +#include +#include + +#include + +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/cc/experimental/base/public/runtime.h" +#include "tensorflow/cc/experimental/base/public/runtime_builder.h" +#include "tensorflow/cc/experimental/base/public/tensor.h" +#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +using tensorflow::experimental::cc::Runtime; +using tensorflow::experimental::cc::RuntimeBuilder; +using tensorflow::experimental::cc::Status; +using tensorflow::experimental::cc::Tensor; +using tensorflow::experimental::cc::TensorHandle; + +using SimpleTypes = ::testing::Types< + tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type, + tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type, + tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>; + +template +class ConstructScalarTensorHandleTest : public ::testing::Test {}; +TYPED_TEST_SUITE(ConstructScalarTensorHandleTest, SimpleTypes); + +// This test constructs a scalar tensor for each of the types in "SimpleTypes", +// then wraps it in a TensorHandle. 
We then unwrap it back into a Tensor, and +// verify the expected dims, dtype, value, num bytes, and num elements. +TYPED_TEST(ConstructScalarTensorHandleTest, + ValidTensorAttributesAfterConstruction) { + Status status; + RuntimeBuilder runtime_builder; + std::unique_ptr runtime = runtime_builder.Build(&status); + ASSERT_TRUE(status.ok()) << status.message(); + + TF_DataType dtype = TypeParam::kDType; + typename TypeParam::type value = 42; + Tensor original_tensor = + Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{}, + /*data=*/&value, + /*len=*/sizeof(value), + /*deleter=*/[](void*, size_t) {}, &status); + ASSERT_TRUE(status.ok()) << status.message(); + + TensorHandle handle = + TensorHandle::FromTensor(original_tensor, *runtime, &status); + ASSERT_TRUE(status.ok()) << status.message(); + + Tensor tensor = handle.Resolve(&status); + ASSERT_TRUE(status.ok()) << status.message(); + + EXPECT_EQ(tensor.dims(), 0); + EXPECT_EQ(tensor.dtype(), dtype); + EXPECT_EQ(*reinterpret_cast(tensor.data()), 42); + EXPECT_EQ(tensor.num_bytes(), sizeof(typename TypeParam::type)); + EXPECT_EQ(tensor.num_elements(), 1); +} + +template +class Construct1DTensorHandleTest : public ::testing::Test {}; +TYPED_TEST_SUITE(Construct1DTensorHandleTest, SimpleTypes); + +// This test constructs a 1D tensor for each of the types in "SimpleTypes", +// and verifies the expected dimensions, dtype, value, number of bytes, and +// number of elements. +TYPED_TEST(Construct1DTensorHandleTest, + ValidTensorAttributesAfterConstruction) { + Status status; + RuntimeBuilder runtime_builder; + std::unique_ptr runtime = runtime_builder.Build(&status); + ASSERT_TRUE(status.ok()) << status.message(); + + TF_DataType dtype = TypeParam::kDType; + // This is our 1D tensor of varying dtype. + std::vector value = {42, 100, 0, 1, 4, 29}; + // Shape is Rank 1 vector. + std::vector shape; + shape.push_back(value.size()); + + Tensor original_tensor = Tensor::FromBuffer( + /*dtype=*/dtype, /*shape=*/shape, + /*data=*/value.data(), + /*len=*/value.size() * sizeof(typename TypeParam::type), + /*deleter=*/[](void*, size_t) {}, &status); + ASSERT_TRUE(status.ok()) << status.message(); + + TensorHandle handle = + TensorHandle::FromTensor(original_tensor, *runtime, &status); + ASSERT_TRUE(status.ok()) << status.message(); + + Tensor tensor = handle.Resolve(&status); + ASSERT_TRUE(status.ok()) << status.message(); + + EXPECT_EQ(tensor.dims(), 1); + EXPECT_EQ(tensor.dtype(), dtype); + tensorflow::gtl::ArraySlice tensor_view( + reinterpret_cast(tensor.data()), value.size()); + EXPECT_EQ(tensor_view[0], 42); + EXPECT_EQ(tensor_view[1], 100); + EXPECT_EQ(tensor_view[2], 0); + EXPECT_EQ(tensor_view[3], 1); + EXPECT_EQ(tensor_view[4], 4); + EXPECT_EQ(tensor_view[5], 29); + + EXPECT_EQ(tensor.num_bytes(), + value.size() * sizeof(typename TypeParam::type)); + EXPECT_EQ(tensor.num_elements(), value.size()); +} + +template +class Construct2DTensorHandleTest : public ::testing::Test {}; +TYPED_TEST_SUITE(Construct2DTensorHandleTest, SimpleTypes); + +// This test constructs a 2D tensor for each of the types in "SimpleTypes", +// and verifies the expected dimensions, dtype, value, number of bytes, and +// number of elements. +TYPED_TEST(Construct2DTensorHandleTest, + ValidTensorAttributesAfterConstruction) { + Status status; + RuntimeBuilder runtime_builder; + std::unique_ptr runtime = runtime_builder.Build(&status); + ASSERT_TRUE(status.ok()) << status.message(); + + TF_DataType dtype = TypeParam::kDType; + // This is our 1D tensor of varying dtype. 
+ std::vector value = {42, 100, 0, 1, 4, 29}; + // Shape is Rank 2 vector with shape 2 x 3. + std::vector shape({2, 3}); + + Tensor original_tensor = Tensor::FromBuffer( + /*dtype=*/dtype, /*shape=*/shape, + /*data=*/value.data(), + /*len=*/value.size() * sizeof(typename TypeParam::type), + /*deleter=*/[](void*, size_t) {}, &status); + ASSERT_TRUE(status.ok()) << status.message(); + + TensorHandle handle = + TensorHandle::FromTensor(original_tensor, *runtime, &status); + ASSERT_TRUE(status.ok()) << status.message(); + + Tensor tensor = handle.Resolve(&status); + ASSERT_TRUE(status.ok()) << status.message(); + + EXPECT_EQ(tensor.dims(), 2); + EXPECT_EQ(tensor.dtype(), dtype); + tensorflow::gtl::ArraySlice tensor_view( + reinterpret_cast(tensor.data()), value.size()); + EXPECT_EQ(tensor_view[0], 42); + EXPECT_EQ(tensor_view[1], 100); + EXPECT_EQ(tensor_view[2], 0); + EXPECT_EQ(tensor_view[3], 1); + EXPECT_EQ(tensor_view[4], 4); + EXPECT_EQ(tensor_view[5], 29); + + EXPECT_EQ(tensor.num_bytes(), + value.size() * sizeof(typename TypeParam::type)); + EXPECT_EQ(tensor.num_elements(), value.size()); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 882b4032f76..b13d8db48a9 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -4,7 +4,6 @@ load( "//tensorflow:tensorflow.bzl", "if_android", - "if_ios", "if_mobile", "if_not_mobile", "tf_cc_test", @@ -85,7 +84,7 @@ cc_library( "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", ]) + if_android([ - "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core:portable_tensorflow_lib", ]), ) diff --git a/tensorflow/cc/saved_model/experimental/public/concrete_function.h b/tensorflow/cc/saved_model/experimental/public/concrete_function.h index f57ba052f1a..1adaf70b01a 100644 --- a/tensorflow/cc/saved_model/experimental/public/concrete_function.h +++ b/tensorflow/cc/saved_model/experimental/public/concrete_function.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/cc/saved_model/experimental/public/function_metadata.h" namespace tensorflow { +namespace experimental { namespace cc { // ConcreteFunction is an executable "function" loaded from a SavedModelAPI. @@ -54,6 +55,7 @@ inline const FunctionMetadata* ConcreteFunction::GetFunctionMetadata() { } } // namespace cc +} // namespace experimental } // namespace tensorflow #endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_ diff --git a/tensorflow/cc/saved_model/experimental/public/concrete_function_list.h b/tensorflow/cc/saved_model/experimental/public/concrete_function_list.h index bab95278eac..88cb779ef15 100644 --- a/tensorflow/cc/saved_model/experimental/public/concrete_function_list.h +++ b/tensorflow/cc/saved_model/experimental/public/concrete_function_list.h @@ -22,6 +22,7 @@ limitations under the License. 
#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h" namespace tensorflow { +namespace experimental { namespace cc { // ConcreteFunctionList helps convert an opaque pointer to an array of @@ -56,6 +57,7 @@ inline std::vector ConcreteFunctionList::ToVector() { } } // namespace cc +} // namespace experimental } // namespace tensorflow #endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_ diff --git a/tensorflow/cc/saved_model/experimental/public/function_metadata.h b/tensorflow/cc/saved_model/experimental/public/function_metadata.h index c3dcc45af0e..11e1a860d84 100644 --- a/tensorflow/cc/saved_model/experimental/public/function_metadata.h +++ b/tensorflow/cc/saved_model/experimental/public/function_metadata.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/public/function_metadata.h" namespace tensorflow { +namespace experimental { namespace cc { // FunctionMetadata stores additional function information, including @@ -40,6 +41,7 @@ class FunctionMetadata final { }; } // namespace cc +} // namespace experimental } // namespace tensorflow #endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_ diff --git a/tensorflow/cc/saved_model/experimental/public/saved_model_api.h b/tensorflow/cc/saved_model/experimental/public/saved_model_api.h index 814479de213..04018bf2aab 100644 --- a/tensorflow/cc/saved_model/experimental/public/saved_model_api.h +++ b/tensorflow/cc/saved_model/experimental/public/saved_model_api.h @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/cc/saved_model/experimental/public/concrete_function_list.h" namespace tensorflow { +namespace experimental { namespace cc { // SavedModelAPI offers a way to load Tensorflow Saved Models @@ -155,6 +156,7 @@ inline std::vector SavedModelAPI::ListFunctions() { } } // namespace cc +} // namespace experimental } // namespace tensorflow #endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_ diff --git a/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc b/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc index 155c58604bf..7f7f6b09a6d 100644 --- a/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc +++ b/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc @@ -26,10 +26,14 @@ limitations under the License. #include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/test.h" -namespace tensorflow { namespace { +using tensorflow::experimental::cc::Runtime; +using tensorflow::experimental::cc::RuntimeBuilder; +using tensorflow::experimental::cc::SavedModelAPI; +using tensorflow::experimental::cc::Status; + constexpr char kTestData[] = "cc/saved_model/testdata"; std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) { @@ -43,21 +47,21 @@ std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) { class CPPSavedModelAPITest : public ::testing::TestWithParam {}; TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) { - cc::Status status; - cc::RuntimeBuilder builder; + Status status; + RuntimeBuilder builder; bool use_tfrt = GetParam(); if (use_tfrt) { GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced. 
} builder.SetUseTFRT(use_tfrt); - std::unique_ptr runtime = builder.Build(&status); + std::unique_ptr runtime = builder.Build(&status); ASSERT_TRUE(status.ok()) << status.message(); std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph"); std::unordered_set tags = {"serve"}; - std::unique_ptr model = - cc::SavedModelAPI::Load(model_dir, *runtime, &status, &tags); + std::unique_ptr model = + SavedModelAPI::Load(model_dir, *runtime, &status, &tags); // TODO(bmzhao): Change this to expect TF_OK when loading is implemented. // That unblocks writing other tests that require a TF_SavedModel*, @@ -67,20 +71,20 @@ TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) { } TEST_P(CPPSavedModelAPITest, LoadsSavedModel) { - cc::Status status; - cc::RuntimeBuilder builder; + Status status; + RuntimeBuilder builder; bool use_tfrt = GetParam(); if (use_tfrt) { GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced. } builder.SetUseTFRT(use_tfrt); - std::unique_ptr runtime = builder.Build(&status); + std::unique_ptr runtime = builder.Build(&status); ASSERT_TRUE(status.ok()) << status.message(); std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph"); - std::unique_ptr model = - cc::SavedModelAPI::Load(model_dir, *runtime, &status); + std::unique_ptr model = + SavedModelAPI::Load(model_dir, *runtime, &status); // TODO(bmzhao): Change this to expect TF_OK when loading is implemented. // That unblocks writing other tests that require a TF_SavedModel*, @@ -94,4 +98,3 @@ INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticCPPSavedModelTests, } // namespace -} // namespace tensorflow diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index c9a36b88795..e4df3090046 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -131,6 +131,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape, TF_RETURN_IF_ERROR(XLATypeToCpp(shape.element_type(), &type)); std::vector dim_vars; string dim_sizes, indices; + int count = 1; if (shape.rank() == 0 || (shape.dimensions_size() == 1 && shape.dimensions(0) == 1)) { dim_sizes = "[1]"; @@ -140,6 +141,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape, dim_vars.push_back(absl::StrCat("size_t dim", dim)); dim_sizes += absl::StrCat("[", shape.dimensions(dim), "]"); indices += absl::StrCat("[dim", dim, "]"); + count *= shape.dimensions(dim); } } rewrites->push_back({"{{I}}", absl::StrCat(i)}); @@ -147,6 +149,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape, rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")}); rewrites->push_back({"{{DIM_SIZES}}", dim_sizes}); rewrites->push_back({"{{INDICES}}", indices}); + rewrites->push_back({"{{COUNT}}", absl::StrCat(count)}); return Status::OK(); } @@ -199,6 +202,12 @@ Status GenArgMethods(const tf2xla::Config& config, return (*static_cast( arg_data({{I}}))){{INDICES}}; } + int arg{{NAME}}_size() const { + return {{COUNT}} * sizeof({{TYPE}}); + } + int arg{{NAME}}_count() const { + return {{COUNT}}; + } )"; *methods += RewriteWithName(absl::StrCat(i), code, rewrites); if (!config.feed(i).name().empty()) { @@ -246,6 +255,12 @@ Status GenResultMethods(const tf2xla::Config& config, return (*static_cast( result_data({{I}}))){{INDICES}}; } + int result{{NAME}}_size() const { + return {{COUNT}} * sizeof({{TYPE}}); + } + int result{{NAME}}_count() const { + return {{COUNT}}; + } )"; *methods += RewriteWithName(absl::StrCat(i), code, rewrites); if (!config.fetch(i).name().empty()) { @@ -281,6 +296,12 @@ 
Status GenVariableMethods(const tf2xla::Config& config, return (*static_cast( arg_data({{I}}))){{INDICES}}; } + int var_{{NAME}}_size() const { + return {{COUNT}} * sizeof({{TYPE}}); + } + int var_{{NAME}}_count() const { + return {{COUNT}}; + } )"; const tf2xla::Variable& var = config.variable(i - config.feed_size()); rewrites.emplace_back("{{MAYBE_CONST}}", var.readonly() ? "const " : ""); diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index af58ca233f0..d011279dbb7 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -138,6 +138,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { return (*static_cast( arg_data(0)))[dim0][dim1]; } + int arg0_size() const { + return 2 * sizeof(float); + } + int arg0_count() const { + return 2; + } void set_arg_myfeed_data(const void* data) { set_arg_data(0, data); @@ -156,6 +162,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { return (*static_cast( arg_data(0)))[dim0][dim1]; } + int arg_myfeed_size() const { + return 2 * sizeof(float); + } + int arg_myfeed_count() const { + return 2; + } void set_arg1_data(const void* data) { set_arg_data(1, data); @@ -174,6 +186,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { return (*static_cast( arg_data(1)))[dim0][dim1]; } + int arg1_size() const { + return 12 * sizeof(tensorflow::int64); + } + int arg1_count() const { + return 12; + } // Result methods for managing output buffers. Buffers are in row-major order. // Must only be called after a successful Run call. There is a set of methods @@ -204,6 +222,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { return (*static_cast( result_data(0)))[dim0][dim1]; } + int result0_size() const { + return 30 * sizeof(tensorflow::uint32); + } + int result0_count() const { + return 30; + } tensorflow::uint32* result_myfetch_data() { return static_cast(result_data(0)); @@ -219,6 +243,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { return (*static_cast( result_data(0)))[dim0][dim1]; } + int result_myfetch_size() const { + return 30 * sizeof(tensorflow::uint32); + } + int result_myfetch_count() const { + return 30; + } // Methods for managing variable buffers. Buffers are in row-major order. // @@ -261,6 +291,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { return (*static_cast( arg_data(2)))[0]; } + int var_myvar_readonly_size() const { + return 1 * sizeof(float); + } + int var_myvar_readonly_count() const { + return 1; + } void set_var_myvar_data(float* data) { set_arg_data(3, data); @@ -279,6 +315,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { return (*static_cast( arg_data(3)))[0]; } + int var_myvar_size() const { + return 1 * sizeof(float); + } + int var_myvar_count() const { + return 1; + } void set_var_myvar2_data(tensorflow::int32* data) { set_arg_data(4, data); @@ -297,6 +339,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { return (*static_cast( arg_data(4)))[dim0]; } + int var_myvar2_size() const { + return 5 * sizeof(tensorflow::int32); + } + int var_myvar2_count() const { + return 5; + } private: // Number of buffers for the compiled computation. 
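Editor's note: a minimal sketch of how a caller might use the new argN_size()/argN_count() and resultN_size()/resultN_count() accessors that codegen.cc now emits (shown in the golden header above). "CompiledGraph" is a placeholder for whatever class tf_library generates via cpp_class; Run(), arg0_data() and result0_data() are the usual tfcompile-generated methods and are assumed here, while the *_size()/*_count() accessors are the ones added in this change.

// Copies an input into the generated buffer and reads back the first result,
// sizing both transfers from the generated constants instead of hard-coding
// shapes.
#include <cassert>
#include <cstring>
#include <vector>

template <typename CompiledGraph>
std::vector<float> RunOnce(CompiledGraph& graph,
                           const std::vector<float>& input) {
  // arg0_size()/arg0_count() come from the {{COUNT}} rewrites added above.
  assert(input.size() * sizeof(float) ==
         static_cast<size_t>(graph.arg0_size()));
  std::memcpy(graph.arg0_data(), input.data(), graph.arg0_size());
  graph.Run();
  const float* out = static_cast<const float*>(graph.result0_data());
  return std::vector<float>(out, out + graph.result0_count());
}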
diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index abccefbcdbb..f2b28e70ff1 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -20,7 +20,7 @@ load( "tf_cc_test", "tf_copts", ) -load("//tensorflow:tensorflow.bzl", "tfcompile_extra_flags") +load("//tensorflow:tensorflow.bzl", "tfcompile_target_cpu") def tf_library( name, @@ -42,7 +42,8 @@ def tf_library( mlir_components = "None", deps = None, tags = []): - """Runs tfcompile to compile a TensorFlow graph into executable code. + """Runs tfcompile to compile a TensorFlow graph into executable code with fast + math enabled on cpu. Given an invocation of tf_library(name="foo", ...), generates the following build targets: @@ -187,7 +188,9 @@ def tf_library( # `find` on such an object. need_xla_data_proto = flags and flags.find("--gen_program_shape") != -1 - flags = tfcompile_extra_flags() + flags + target_cpu = tfcompile_target_cpu() + extra_flags = "--target_cpu=" + target_cpu + " " if target_cpu else " " + flags = extra_flags + flags if enable_xla_hlo_profiling: profiling_flag = "--xla_hlo_profile" @@ -207,6 +210,15 @@ def tf_library( srcs.append(debug_info) debug_info_flag = " --debug_info=$(location " + debug_info + ")" + default_fast_math_xla_flags = ("XLA_FLAGS='" + + "--xla_cpu_enable_fast_math=true " + + "--xla_cpu_fast_math_honor_nans=false " + + "--xla_cpu_fast_math_honor_infs=false " + + "--xla_cpu_fast_math_honor_functions=false " + + "--xla_cpu_fast_math_honor_division=false " + + "--xla_cpu_enable_fast_min_max=true " + + "$${XLA_FLAGS:-}' ") + native.genrule( name = ("gen_" + name), srcs = srcs, @@ -216,6 +228,7 @@ def tf_library( function_object_file, ], cmd = ( + default_fast_math_xla_flags + "CUDA_VISIBLE_DEVICES='' " + "$(location " + tfcompile_tool + ")" + " --graph=$(location " + tfcompile_graph + ")" + @@ -256,6 +269,7 @@ def tf_library( session_module_pb, ], cmd = ( + default_fast_math_xla_flags + "CUDA_VISIBLE_DEVICES='' " + "$(location " + tfcompile_tool + ")" + " --graph=$(location " + tfcompile_graph + ")" + diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index f0cf8f2ded9..846947454bb 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -67,6 +67,8 @@ int main(int argc, char** argv) { flags.entry_point = "entry"; flags.debug_info_path_begin_marker = ""; + // Note that tfcompile.bzl's tf_library macro sets fast math flags as that is + // generally the preferred case. 
std::vector flag_list; AppendMainFlags(&flag_list, &flags); xla::AppendDebugOptionsFlags(&flag_list); diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 28d922f9e3c..5ec0575ed77 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -251,7 +251,7 @@ cc_library( visibility = [":friends"], deps = select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core:portable_tensorflow_lib", ], "//conditions:default": [ "//tensorflow/core:graph", @@ -505,6 +505,7 @@ cc_library( name = "shape_inference", srcs = ["shape_inference.cc"], hdrs = ["shape_inference.h"], + visibility = [":friends"], deps = [ ":shape_inference_helpers", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 174250f18bd..9f5723f4fa4 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -2034,6 +2034,7 @@ absl::flat_hash_set GetKnownXLAWhitelistOp() { "TensorArraySplitV3", "TensorArrayV3", "TensorArrayWriteV3", + "TensorListConcatV2", "TensorListElementShape", "TensorListFromTensor", "TensorListGather", @@ -2043,6 +2044,7 @@ absl::flat_hash_set GetKnownXLAWhitelistOp() { "TensorListPushBack", "TensorListReserve", "TensorListSetItem", + "TensorListSplit", "TensorListStack", "TensorScatterAdd", "TensorScatterSub", diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index abb42aa1815..7842513331d 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -395,12 +395,11 @@ static void ShowXlaDeviceDeprecationWarning( if (absl::StrContains(compilation_device_name, "CPU") || absl::StrContains(compilation_device_name, "GPU")) { absl::call_once(once, [] { - LOG(WARNING) - << "XLA_GPU and XLA_CPU devices are deprecated and will be " - "removed in subsequent releases. Instead, use either " - "@tf.function(experimental_compile=True) for must-compile " - "semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 " - "for auto-clustering best-effort compilation."; + LOG(INFO) << "XLA_GPU and XLA_CPU devices are deprecated and will be " + "removed in subsequent releases. 
Instead, use either " + "@tf.function(experimental_compile=True) for must-compile " + "semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 " + "for auto-clustering best-effort compilation."; }); } } diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 34ff0c55615..17e4226405a 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -180,12 +180,10 @@ class XlaAssignVariableOp : public OpKernel { data::MakeIteratorOp); \ REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE), \ data::AnonymousIteratorHandleOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("AnonymousIteratorV2").Device(DEVICE).HostMemory("deleter"), \ - data::AnonymousIteratorHandleOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("DeleteIterator").Device(DEVICE).HostMemory("deleter"), \ - data::DeleteIteratorOp); \ + REGISTER_KERNEL_BUILDER(Name("AnonymousIteratorV2").Device(DEVICE), \ + data::AnonymousIteratorHandleOp); \ + REGISTER_KERNEL_BUILDER(Name("DeleteIterator").Device(DEVICE), \ + data::DeleteIteratorOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \ data::IteratorGetNextOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE), \ diff --git a/tensorflow/compiler/jit/xla_kernel_creator_util.cc b/tensorflow/compiler/jit/xla_kernel_creator_util.cc index 99046c0bd76..3cc68f2a1a4 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator_util.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator_util.cc @@ -91,7 +91,7 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def, } string message = absl::StrCat( "Function invoked by the following node is not compilable: ", - SummarizeNodeDef(node_def), ".\n"); + SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10), ".\n"); absl::StrAppend(&message, "Uncompilable nodes:"); for (const auto& node_info : uncompilable_node_info) { string node_message = diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index e0ec990462b..8c24f182f5c 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -201,9 +201,7 @@ void XlaComputationLaunchContext::PopulateInputs( se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; // Build ShapedBuffers that point directly to the Tensor buffers. - arg_buffers_.reserve(kernel->xla_input_shapes.size() + 1); - arg_buffers_.resize(kernel->xla_input_shapes.size()); - arg_ptrs_ = std::vector(arg_buffers_.size()); + arg_ptrs_ = std::vector(kernel->xla_input_shapes.size()); // Pass remaining parameters. 
const Tensor* t; @@ -239,11 +237,11 @@ void XlaComputationLaunchContext::PopulateInputs( << " not the same as on-host shape " << xla::ShapeUtil::HumanStringWithLayout(shape); se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t); - arg_buffers_[i] = absl::make_unique( + arg_buffers_.emplace_back( /*on_host_shape=*/shape, /*on_device_shape=*/shape, client_->platform(), client_->default_device_ordinal()); - arg_buffers_[i]->set_buffer(dmem, /*index=*/{}); - arg_ptrs_[i] = arg_buffers_[i].get(); + arg_buffers_.back().set_buffer(dmem, /*index=*/{}); + arg_ptrs_[i] = &arg_buffers_.back(); } } } diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 511e0f1451a..cf68dcb7dd6 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -165,7 +165,7 @@ class XlaComputationLaunchContext { se::DeviceMemoryAllocator* xla_allocator_; bool allocate_xla_tensors_; bool use_multiple_streams_; - std::vector> arg_buffers_; + std::deque arg_buffers_; std::vector arg_ptrs_; }; diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index 4a4d566f163..c4472e1185c 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -77,10 +77,6 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", "//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo", "//tensorflow/compiler/mlir/tfjs:tensorflow_js_passes", - "//tensorflow/compiler/mlir/tfrt:lower_tf_to_tfd_alwayslink", - "//tensorflow/compiler/mlir/tfrt:runtime_fallback_opdefs_alwayslink", - "//tensorflow/compiler/mlir/tfrt:tf_legalize_to_tfrt", - "//tensorflow/compiler/mlir/tfrt:tf_to_corert", ], ) @@ -108,6 +104,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@llvm-project//llvm:support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Shape", "@llvm-project//mlir:StandardOps", ], alwayslink = 1, @@ -152,7 +149,6 @@ tf_cc_binary( "//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/compiler/mlir/tensorflow:translate_registration", "//tensorflow/compiler/mlir/tensorflow:translate_tf_dialect_op", - "//tensorflow/compiler/mlir/tfrt:compatibility_analysis", "//tensorflow/compiler/mlir/xla:xla_mlir_translate", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index d907d28b2c7..c60c0c0edbf 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -31,7 +31,7 @@ filegroup( "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files", "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td", - "@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td", + "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", ], ) @@ -216,13 +216,13 @@ cc_library( "ir/tfl_ops.h", "transforms/passes.h", "utils/attribute_utils.h", - "//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h", "@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h", ], deps = [ ":tensorflow_lite_ops_inc_gen", ":validators", "//tensorflow/compiler/mlir/lite/experimental/estimators:cost_estimators", + "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/lite/schema:schema_fbs", "@llvm-project//llvm:support", @@ -260,6 +260,25 @@ cc_library( ], ) +cc_library( + name = "tftext_utils", + srcs = [ + 
"utils/tftext_utils.cc", + ], + hdrs = [ + "utils/tftext_utils.h", + ], + copts = ["-std=c++14"], + deps = [ + ":tensorflow_lite", + "//tensorflow/compiler/mlir/tensorflow", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "stateful_ops_utils", srcs = [ @@ -320,6 +339,7 @@ cc_library( ":lstm_utils", ":stateful_ops_utils", ":tensorflow_lite", + ":tftext_utils", ":validators", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", @@ -695,9 +715,9 @@ cc_library( "@com_google_absl//absl/strings", "@llvm-project//llvm:support", "@llvm-project//mlir:IR", - "@llvm-project//mlir:LoopOpsTransforms", "@llvm-project//mlir:MlirTranslateMain", "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:SCFTransforms", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:Translation", diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index e9192388070..df84b028f63 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -799,11 +799,6 @@ Optional Translator::CreateFlexOpCustomOptions( Optional Translator::CreateCustomOpCustomOptions( const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { - std::string node_def_str; - if (!node_def.SerializeToString(&node_def_str)) { - return emitError(loc, "failed to serialize tensorflow node_def"), - llvm::None; - } auto flex_builder = CreateFlexBuilderWithNodeAttrs(node_def, loc); return builder_.CreateVector(flex_builder->GetBuffer()); } @@ -813,9 +808,13 @@ Translator::CreateFlexBuilderWithNodeAttrs( const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { auto flex_builder = absl::make_unique(); size_t map_start = flex_builder->StartMap(); - for (const auto& pair : node_def.attr()) { + using Item = std::pair; + std::vector attrs(node_def.attr().begin(), node_def.attr().end()); + std::sort(attrs.begin(), attrs.end(), + [](Item& p1, Item& p2) -> bool { return p1.first < p2.first; }); + for (const Item& pair : attrs) { const char* key = pair.first.c_str(); - const auto& attr = pair.second; + const ::tensorflow::AttrValue& attr = pair.second; switch (attr.value_case()) { case ::tensorflow::AttrValue::kS: flex_builder->String(key, attr.s()); @@ -1020,7 +1019,7 @@ Optional> Translator::BuildOperator( if (!inst->getMutableAttrDict().getAttrs().empty()) { os << " {"; bool first = true; - for (auto& named_attr : inst->getMutableAttrDict().getDictionary()) { + for (auto& named_attr : inst->getAttrDictionary()) { os << (!first ? 
", " : ""); first = false; named_attr.first.print(os); diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 3dcfe71770b..0c384ebf9f3 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -1966,9 +1966,9 @@ OpFoldResult TransposeOp::fold(ArrayRef operands) { } static LogicalResult Verify(TransposeOp op) { - auto input_type = op.x().getType().cast(); + auto input_type = op.input().getType().cast(); auto perm_type = op.perm().getType().cast(); - auto output_type = op.y().getType().cast(); + auto output_type = op.output().getType().cast(); if (input_type.hasStaticShape() && perm_type.hasStaticShape()) { if (perm_type.getNumElements() != input_type.getRank()) { return op.emitOpError( diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h index 0e6a3db1f1b..c7a1504c3b7 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h @@ -27,7 +27,7 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project #include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project -#include "mlir/Interfaces/SideEffects.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h" #include "tensorflow/lite/schema/schema_generated.h" diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index d103c07b986..48bc68e5c95 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -20,7 +20,7 @@ limitations under the License. 
include "mlir/IR/OpBase.td" include "mlir/Interfaces/LoopLikeInterface.td" -include "mlir/Interfaces/SideEffects.td" +include "mlir/Interfaces/SideEffectInterfaces.td" include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td" include "tensorflow/compiler/mlir/lite/quantization/quantization.td" @@ -161,7 +161,6 @@ class TFL_VariadicTensorOf allowedRuntimeTypes, Variadic>, TFL_RuntimeType>>; -def TFL_Uint8 : UI<8>; def TFL_Int32Or64 : SignlessIntOfWidths<[32, 64]>; def TFL_BoolTensor : TFL_TensorOf<[I1]>; @@ -228,10 +227,14 @@ class TFL_OperandHasAtleastRank : class TFL_OperandRankEquals1DimOfOperand : PredOpTrait<"operand " # x # "'s rank equals operand " # y # "'s size", - CPred<"$_op.getOperand(" # x # - ").getType().cast().getRank() == " - "$_op.getOperand(" # y # - ").getType().cast().getShape()[0]">>; + Or<[TFL_OperandIsUnrankedPred, + TFL_OperandIsUnrankedPred, + CPred<"!$_op.getOperand(" # y # + ").getType().cast().hasStaticShape()">, + CPred<"$_op.getOperand(" # x # + ").getType().cast().getRank() == " + "$_op.getOperand(" # y # + ").getType().cast().getShape()[0]">]>>; class TFL_Operand0DOr1ElementTensor : PredOpTrait<"operand #" # x # " is an 0-d tensor or 1-d tensor w/ 1 element", @@ -247,7 +250,22 @@ class TFL_TFTypesWithSameBits : Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa()">, CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>; -class TFL_OperandIsNoneOrHasRankLessThanOrEqualTo : +class TFL_TFOperandTypesWithSameBits : + And<[ + Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # i # ")).isa()">, + CPred<"getElementTypeOrSelf($_op.getOperand(" # i # ")).isUnsignedInteger(" # num # ")">]>, + Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa()">, + CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>; + +class TFL_OperandIsNoneOrHasRank : + PredOpTrait<"operand " # n # " is " # m # "-D", + Or<[ + CPred<"$_op.getOperand(" # n # ").getType().isa()">, + TFL_OperandIsUnrankedPred, + CPred<"$_op.getOperand(" # n # + ").getType().cast().getRank() == " # m>]>>; + +class TFL_OperandIsNoneOrHasRankAtMost : PredOpTrait<"operand " # n # " is at most " # m # "-D", Or<[ CPred<"$_op.getOperand(" # n # ").getType().isa()">, @@ -255,13 +273,13 @@ class TFL_OperandIsNoneOrHasRankLessThanOrEqualTo : CPred<"$_op.getOperand(" # n # ").getType().cast().getRank() <= " # m>]>>; -class TFL_OperandHasRankLessThanOrEqualTo : +class TFL_OperandHasRankAtMost : PredOpTrait<"operand " # n # " is at most " # m # "-D", Or<[TFL_OperandIsUnrankedPred, CPred<"$_op.getOperand(" # n # ").getType().cast().getRank() <= " # m>]>>; -class TFL_OperandHasRankGreaterThanOrEqualTo : +class TFL_OperandHasRankAtLeast : PredOpTrait<"operand " # n # " is at least " # m # "-D", Or<[TFL_OperandIsUnrankedPred, CPred<"$_op.getOperand(" # n # @@ -275,17 +293,33 @@ class TFL_OperandHasRankRange : "getRank() <= " # y>]>>; def TFL_FloatNonNegative : AttrConstraint< - CPred<"!$_self.cast().getValue().isNegative()">, + CPred<"$_self.isa() && " + "!$_self.cast().getValue().isNegative()">, "whose value is non-negative">; -def TFL_BoolTrue: AttrConstraint< - CPred<"$_self.cast().getValue()">, +def TFL_BoolTrue : AttrConstraint< + CPred<"$_self.isa() && $_self.cast().getValue()">, "whose value is true">; -def TFL_BoolFalse: AttrConstraint< - CPred<"!$_self.cast().getValue()">, +def TFL_BoolFalse : AttrConstraint< + CPred<"$_self.isa() && !$_self.cast().getValue()">, "whose value is false">; +class TFL_StringEqualsTo : 
AttrConstraint< + CPred<"$_self.cast().getValue() == \"" # value # "\"">, + "whose value equals to '" # value # "'">; + +// Ensures the array attribute's size is within the given maximum size. +class TFL_ArrayMaxCount : AttrConstraint< + CPred<"$_self.isa() && $_self.cast().size() <= " # n>, + "whose size is at most " # n>; + +// Ensures the given integer attribute has the given value. +class TFL_IntEqualsTo : AttrConstraint< + CPred<"$_self.isa() && " + "$_self.cast().getInt() == " # n>, + "whose value is " # n>; + // This is a quantization-aware version of TCresVTEtIsSameAsOp class TFL_TCresVTEtIsSameAsOp : And<[ TCOpResIsShapedTypePred, @@ -300,6 +334,18 @@ class TFL_TCresVTEtIsSameAsOp : And<[ "quant::QuantizedType::castToStorageType(" "getElementTypeOrSelf($_op.getOperand(" # j # ")))">]>]>]>; +// This is a quantization-aware version of TCresVTEtIsSameAsOp +class TFL_TCopVTEtAreSameAt : Or<[ + TCopVTEtAreSameAt<[i, j]>, + TFL_TFOperandTypesWithSameBits, + And<[ + SubstLeaves<"$_self", "getElementTypeOrSelf($_op.getOperand(" # j # "))", + quant_QuantizedType.predicate>, + CPred<"quant::QuantizedType::castToStorageType(" + "getElementTypeOrSelf($_op.getOperand(" # i # "))) == " + "quant::QuantizedType::castToStorageType(" + "getElementTypeOrSelf($_op.getOperand(" # j # ")))">]>]>; + //===----------------------------------------------------------------------===// // TFL op common constraints. //===----------------------------------------------------------------------===// @@ -395,9 +441,9 @@ class TFL_ConvOp : }]; let arguments = ( - ins TFL_TensorOf<[F32, QI8, QUI8]>:$input, + ins TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$input, TFL_TensorOf<[F32, QI8, QUI8]>:$filter, - TFL_TensorOfOrNone<[F32, I32]>:$bias, + TFL_TensorOfOrNone<[F32, I32, I64]>:$bias, I32Attr:$dilation_h_factor, I32Attr:$dilation_w_factor, TFL_AFAttr:$fused_activation_function, @@ -406,7 +452,7 @@ class TFL_ConvOp : I32Attr:$stride_w ); - let results = (outs TFL_TensorOf<[F32, QI8, QUI8]>:$output); + let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$output); let hasOptions = 0b1; } @@ -437,7 +483,10 @@ an output element, this operation computes \\(y = |x|\\). def TFL_AddOp : TFL_Op<"add", [ TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>, - ResultsBroadcastableShape, NoSideEffect, Commutative, TFL_GpuTargetOp]> { + ResultsBroadcastableShape, + NoSideEffect, + Commutative, + TFL_GpuTargetOp]> { let summary = "Addition operator"; let description = [{ @@ -505,8 +554,14 @@ retained with length 1. 
let customOption = "ReducerOptions"; } -def TFL_TransposeConvOp: - TFL_Op<"transpose_conv", [NoSideEffect, TFL_GpuTargetOp]> { +def TFL_TransposeConvOp: TFL_Op<"transpose_conv", [ + NoSideEffect, + TFL_OperandHasRank<0, 1>, + TFL_OperandHasRank<1, 4>, + TFL_OperandHasRank<2, 4>, + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 2>>, + TFL_GpuTargetOp]> { let summary = "Transpose convolution operator"; let description = [{ @@ -514,16 +569,16 @@ def TFL_TransposeConvOp: }]; let arguments = (ins - TFL_1DTensorOf<[I32]>:$output_shape, - TFL_TensorOf<[F32, TFL_Uint8, QI8, QUI8]>:$weights, - TFL_TensorOf<[F32, TFL_Uint8, QI8, QUI8]>:$input, - TFL_TensorOfOrNone<[F32, QI32, QUI32]>:$bias, + TFL_I32Tensor:$output_shape, + TFL_TensorOf<[F32, QI8, QUI8]>:$weights, + TFL_TensorOf<[F32, QI8, QUI8]>:$input, + TFL_TensorOfOrNone<[F32, QI32]>:$bias, TFL_PaddingAttr:$padding, - I32Attr:$stride_h, - I32Attr:$stride_w + Confined:$stride_h, + Confined:$stride_w ); - let results = (outs AnyTensor:$output); + let results = (outs TFL_TensorOf<[F32, QI8, QUI8]>:$output); let hasOptions = 1; @@ -565,7 +620,7 @@ def TFL_ArgMaxOp : TFL_Op<"arg_max", [NoSideEffect]> { }]; let arguments = ( - ins TFL_TensorOf<[F32, I32, I8, TFL_Uint8, QI8, QUI8]>:$input, + ins TFL_TensorOf<[F32, I32, I8, UI8, QI8, QUI8]>:$input, TFL_I32OrI64Tensor:$dim ); @@ -595,7 +650,7 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> { }]; let arguments = ( - ins TFL_TensorOf<[F32, I32, I8, TFL_Uint8, QI8, QUI8]>:$input, + ins TFL_TensorOf<[F32, I32, I8, UI8, QI8, QUI8]>:$input, TFL_I32OrI64Tensor:$dim ); @@ -642,14 +697,14 @@ def TFL_ConcatenationOp : TFL_Op<"concatenation", let arguments = ( ins TFL_VariadicTensorOf< - [F32, I64, I32, I16, I8, QI8, QUI8, TFL_Uint8]>:$values, + [F32, I64, I32, I16, I8, QI8, QUI8, UI8]>:$values, I32Attr:$axis, TFL_AFAttr:$fused_activation_function ); let results = (outs TFL_TensorOf< - [F32, I64, I32, I16, I8, QI8, QUI8, TFL_Uint8]>:$output + [F32, I64, I32, I16, I8, QI8, QUI8, UI8]>:$output ); let hasOptions = 1; @@ -713,7 +768,8 @@ def SparsityParameterAttr : StructAttr<"SparsityParameterAttr", TFL_Dialect, [ let storageType = [{ TFL::SparsityParameterAttr }]; } -def TFL_SparseConstOp : Op { let summary = "Sparse constant pseudo op."; @@ -924,12 +980,12 @@ def TFL_GatherNdOp : TFL_Op<"gather_nd", [ }]; let arguments = (ins - TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8, TFL_Str]>:$params, + TFL_TensorOf<[F32, I8, I64, I32, UI8, TFL_Str]>:$params, TFL_I32OrI64Tensor:$indices ); let results = (outs - TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8, TFL_Str]>:$output + TFL_TensorOf<[F32, I8, I64, I32, UI8, TFL_Str]>:$output ); } @@ -948,12 +1004,12 @@ def TFL_ScatterNdOp : TFL_Op<"scatter_nd", [ let arguments = (ins TFL_TensorOf<[I32]>:$indices, - TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$updates, + TFL_TensorOf<[F32, I8, I64, I32, UI8]>:$updates, TFL_1DTensorOf<[I32]>:$shape ); let results = (outs - TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$output + TFL_TensorOf<[F32, I8, I64, I32, UI8]>:$output ); let verifier = [{ return Verify(*this); }]; @@ -963,7 +1019,11 @@ def TFL_ScatterNdOp : TFL_Op<"scatter_nd", [ // Same type check of lhs and rhs is handled by the ResultsBroadcastableShape trait. 
def TFL_LessEqualOp : TFL_Op<"less_equal", [ - ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> { + ResultsBroadcastableShape, + BinaryOpSameElementTypeConstraint, + TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>, + NoSideEffect, + NoQuantizableResult]> { let summary = "Less_equal operator"; let description = [{ @@ -971,8 +1031,8 @@ def TFL_LessEqualOp : TFL_Op<"less_equal", [ }]; let arguments = ( - ins TFL_TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$lhs, - TFL_TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$rhs); + ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8]>:$lhs, + TFL_TensorOf<[F32, I32, I64, QI8, QUI8]>:$rhs); let results = (outs TFL_BoolTensor:$output); @@ -985,9 +1045,12 @@ def TFL_LessEqualOp : TFL_Op<"less_equal", [ let hasOptions = 0; } -def TFL_LocalResponseNormalizationOp : TFL_Op<"local_response_normalization", - [NoSideEffect]> { - let summary = "Local Response Normalization."; +def TFL_LocalResponseNormalizationOp : TFL_Op<"local_response_normalization", [ + TFL_OperandHasRank<0, 4>, + SameOperandsAndResultShape, + SameOperandsAndResultType, + NoSideEffect]> { + let summary = "Local Response Normalization."; let description = [{ The 4-D `input` tensor is treated as a 3-D array of 1-D vectors (along the last @@ -1004,7 +1067,7 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag }]; let arguments = (ins - TFL_TensorOf<[F32, QI8, QUI8]>:$input, + TFL_FpTensor:$input, I32Attr:$radius, F32Attr:$bias, F32Attr:$alpha, @@ -1012,7 +1075,7 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag ); let results = (outs - TFL_TensorOf<[F32, QI8, QUI8]>:$output + TFL_FpTensor:$output ); let hasOptions = 1; @@ -1048,7 +1111,7 @@ def TFL_MatrixDiagOp : TFL_Op<"matrix_diag", [ NoSideEffect, TFL_OperandHasAtleastRank<0, 1>, PredOpTrait<"operand and result must have the same element type", - TCresVTEtIsSameAsOp<0, 0>>]> { + TFL_TCresVTEtIsSameAsOp<0, 0>>]> { let summary = [{ Returns a tensor with the provided diagonal and everything else padded with zeros. }]; @@ -1061,17 +1124,21 @@ def TFL_MatrixDiagOp : TFL_Op<"matrix_diag", [ }]; let arguments = (ins - TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$diagonal + TFL_TensorOf<[F32, I8, I16, I32, I64, UI8, QUI8, QI8, TFL_Quint8]>:$diagonal ); let results = (outs - TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$output + TFL_TensorOf<[F32, I8, I16, I32, I64, UI8, QUI8, QI8, TFL_Quint8]>:$output ); let hasOptions = 0; } -def TFL_MatrixSetDiagOp : TFL_Op<"matrix_set_diag", [NoSideEffect]> { +def TFL_MatrixSetDiagOp : TFL_Op<"matrix_set_diag", [ + TFL_OperandHasAtleastRank<0, 2>, + PredOpTrait<"input and result must have the same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoSideEffect]> { let summary = [{ Returns a batched matrix tensor with new batched diagonal values. }]; @@ -1083,12 +1150,12 @@ innermost matrices. These will be overwritten by the values in `diagonal`. 
}]; let arguments = (ins - TensorOf<[F32, I32, I64, I8, QI8, QI16, QUI8, TFL_Uint8, TFL_Quint8]>:$input, - TensorOf<[F32, I32, I64, I8, QI8, QI16, QUI8, TFL_Uint8, TFL_Quint8]>:$diagonal + TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$input, + TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$diagonal ); let results = (outs - TensorOf<[F32, I32, I64, I8, QI8, QI16, QUI8, TFL_Uint8, TFL_Quint8]>:$output + TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$result ); let hasOptions = 0; @@ -1206,7 +1273,12 @@ larger than 0. } def TFL_NotEqualOp : TFL_Op<"not_equal", [ - ResultsBroadcastableShape, Commutative, NoSideEffect, NoQuantizableResult]> { + TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>, + BinaryOpSameElementTypeConstraint, + ResultsBroadcastableShape, + Commutative, + NoSideEffect, + NoQuantizableResult]> { let summary = "Not_equal operator"; let description = [{ @@ -1214,8 +1286,8 @@ def TFL_NotEqualOp : TFL_Op<"not_equal", [ }]; let arguments = ( - ins AnyTensor:$lhs, - AnyTensor:$rhs); + ins TFL_TensorOf<[I1, F32, I32, I64, QUI8, QI8, TFL_Quint8, TFL_Str]>:$lhs, + TFL_TensorOf<[I1, F32, I32, I64, QUI8, QI8, TFL_Quint8, TFL_Str]>:$rhs); let results = (outs TFL_BoolTensor:$output); @@ -1234,8 +1306,10 @@ def TFL_NotEqualOp : TFL_Op<"not_equal", [ } def TFL_DivOp : TFL_Op<"div", [ - // TODO(fengliuai): NoQuantizableResult is only correct for int8 - // quantization. update to handle Uint8 quantization. + // TODO(fengliuai): NoQuantizableResult is only correct for int8 + // quantization. update to handle Uint8 quantization. + BinaryOpSameElementTypeConstraint, + TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>, ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult, @@ -1248,10 +1322,10 @@ def TFL_DivOp : TFL_Op<"div", [ let arguments = ( ins TFL_TensorOf<[F32, I32, QUI8]>:$lhs, - TFL_TensorOf<[F32, I32, TFL_Uint8]>:$rhs, + TFL_TensorOf<[F32, I32, QUI8]>:$rhs, TFL_AFAttr:$fused_activation_function); - let results = (outs TFL_TensorOf<[F32, I32, TFL_Uint8]>:$output); + let results = (outs TFL_TensorOf<[F32, I32, QUI8]>:$output); let builders = [TFL_FusedBroadcastableBinaryBuilder]; @@ -1284,7 +1358,7 @@ def TFL_EmbeddingLookupOp: TFL_Op<"embedding_lookup", PredOpTrait<"value and output must have same element type", TFL_TCresVTEtIsSameAsOp<0, 1>>, TFL_OperandHasRank<0, 1>, - TFL_OperandHasRankGreaterThanOrEqualTo<1, 2> + TFL_OperandHasRankAtLeast<1, 2> ]> { let summary = "Embedding lookup operator"; @@ -1294,10 +1368,10 @@ def TFL_EmbeddingLookupOp: TFL_Op<"embedding_lookup", let arguments = (ins TFL_TensorOf<[I32]>:$lookup, - TFL_TensorOf<[F32, I8, TFL_Uint8]>:$value + TFL_TensorOf<[F32, I8, UI8]>:$value ); - let results = (outs TFL_TensorOf<[F32, I8, TFL_Uint8]>:$output); + let results = (outs TFL_TensorOf<[F32, I8, UI8]>:$output); } def TFL_EqualOp: TFL_Op<"equal", [Commutative, ResultsBroadcastableShape, @@ -1313,8 +1387,8 @@ def TFL_EqualOp: TFL_Op<"equal", [Commutative, ResultsBroadcastableShape, let arguments = ( ins - TFL_TensorOf<[I1, F32, I32, I64, QI8, QUI8, TFL_Uint8, TFL_Str]>:$x, - TFL_TensorOf<[I1, F32, I32, I64, QI8, QUI8, TFL_Uint8, TFL_Str]>:$y + TFL_TensorOf<[I1, F32, I32, I64, QI8, QUI8, UI8, TFL_Str]>:$x, + TFL_TensorOf<[I1, F32, I32, I64, QI8, QUI8, UI8, TFL_Str]>:$y ); let results = (outs TFL_BoolTensor:$output); @@ -1394,7 +1468,7 @@ def TFL_SqueezeOp: TFL_Op<"squeeze", [NoSideEffect, Given a tensor `input`, this operation returns a tensor of the same type with all 
dimensions of size 1 removed. If you don't want to remove all size 1 dimensions, you can remove specific size 1 dimensions by specifying -`axis`. +`squeeze_dims`. For example: @@ -1413,7 +1487,7 @@ shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1] let arguments = (ins AnyTensor:$input, - DefaultValuedAttr:$squeeze_dims + Confined, [TFL_ArrayMaxCount<8>]>:$squeeze_dims ); let results = (outs @@ -1502,7 +1576,11 @@ def TFL_FloorModOp : TFL_Op<"floor_mod", [ } def TFL_GreaterOp : TFL_Op<"greater", [ - ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> { + ResultsBroadcastableShape, + BinaryOpSameElementTypeConstraint, + TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>, + NoSideEffect, + NoQuantizableResult]> { let summary = "Greater operator"; let description = [{ @@ -1510,10 +1588,10 @@ def TFL_GreaterOp : TFL_Op<"greater", [ }]; let arguments = ( - ins AnyTensor:$lhs, - AnyTensor:$rhs); + ins TFL_TensorOf<[F32, I32, I64, QUI8, QI8, TFL_Quint8]>:$lhs, + TFL_TensorOf<[F32, I32, I64, QUI8, QI8, TFL_Quint8]>:$rhs); - let results = (outs AnyTensor:$output); + let results = (outs TFL_BoolTensor:$output); let builders = [TFL_ComparisonBinaryBuilder]; @@ -1522,9 +1600,12 @@ def TFL_GreaterOp : TFL_Op<"greater", [ let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }]; } -def TFL_HardSwishOp: TFL_Op<"hard_swish", [NoSideEffect, - SameOperandsAndResultShape, - TFL_GpuTargetOp]> { +def TFL_HardSwishOp: TFL_Op<"hard_swish", [ + NoSideEffect, + SameOperandsAndResultShape, + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + TFL_GpuTargetOp]> { let summary = "Hardswish activation function."; let description = [{ Computes hard-swish activation function @@ -1534,7 +1615,7 @@ def TFL_HardSwishOp: TFL_Op<"hard_swish", [NoSideEffect, let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8]>:$input); - let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$out); + let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$output); let hasOptions = 0; } @@ -1563,29 +1644,35 @@ def TFL_L2NormalizationOp : TFL_Op<"l2_normalization", [NoSideEffect, let customOption = "L2NormOptions"; } -def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [NoSideEffect, SameOperandsAndResultType]> { +def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [ + SameOperandsAndResultShape, + NoSideEffect, + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>]> { let summary = "Leaky Relu operator"; - // TODO(jpienaar): Add type restriction. This op is only defined for - // restricted (floating point) types. let description = [{ Element-wise Leaky ReLU operator x -> x >= 0 ? x : (alpha * x) }]; let arguments = ( - ins AnyTensor:$input, + ins TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$input, // Slope of the activation function at x < 0. 
F32Attr:$alpha ); - let results = (outs AnyTensor:$output); + let results = (outs TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$output); let hasOptions = 0b1; } def TFL_LessOp : TFL_Op<"less", [ - ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> { + ResultsBroadcastableShape, + BinaryOpSameElementTypeConstraint, + TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>, + NoSideEffect, + NoQuantizableResult]> { let summary = "Less operator"; let description = [{ @@ -1593,8 +1680,8 @@ def TFL_LessOp : TFL_Op<"less", [ }]; let arguments = ( - ins AnyTensor:$lhs, - AnyTensor:$rhs); + ins TFL_TensorOf<[F32, I32, I64, QUI8, QI8, TFL_Quint8]>:$lhs, + TFL_TensorOf<[F32, I32, I64, QUI8, QI8, TFL_Quint8]>:$rhs); let results = (outs TFL_BoolTensor:$output); @@ -1655,11 +1742,14 @@ def TFL_LogicalOrOp : TFL_Op<"logical_or", [NoSideEffect]> { def TFL_LogisticOp: TFL_Op<"logistic", [ NoSideEffect, + PredOpTrait<"x and y must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, SameOperandsAndResultShape, // zero_point = 0 // scale = 1. / (max_value + 1) FixedResultScale>, FixedResultScale>, + FixedOutputRangeInterface, TFL_GpuTargetOp]> { let summary = "Logistic operator"; @@ -1667,9 +1757,39 @@ def TFL_LogisticOp: TFL_Op<"logistic", [ Computes element-wise Sigmoid of input }]; - let arguments = (ins TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$x); + let arguments = (ins TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$x); - let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$y); + let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$y); + + let extraClassDeclaration = [{ + // FixedOutputRangeInterface: + quant::UniformQuantizedType GetFixedOutputRange( + bool is_signed, int bit_width) { + auto result_type = y().getType().cast(); + if (!result_type.getElementType().isa()) return {}; + Builder builder(result_type.getContext()); + + // Only support 8-bits + if (bit_width != 8) return {}; + IntegerType storage_type = builder.getIntegerType(bit_width); + + double scale = 1.0 / 256; + int64_t zero_point, storage_min, storage_max; + if (is_signed) { + zero_point = -128; + storage_min = -128; + storage_max = 127; + } else { + zero_point = 0; + storage_min = 0; + storage_max = 255; + } + + return quant::UniformQuantizedType::getChecked( + is_signed, storage_type, result_type.getElementType(), scale, + zero_point, storage_min, storage_max, builder.getUnknownLoc()); + } + }]; } def TFL_LogOp: TFL_Op<"log", [ @@ -1690,10 +1810,11 @@ def TFL_LogOp: TFL_Op<"log", [ let hasFolder = 1; } -// TODO(b/130643170): Adds some constraint for the input/output element types. 
def TFL_LogSoftmaxOp : TFL_Op<"log_softmax", [ NoSideEffect, SameOperandsAndResultShape, + PredOpTrait<"x and y must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, // zero_point = max_value // scale = -log_softmax_output_min / (max_value + 1) FixedResultScale>, @@ -1706,9 +1827,9 @@ def TFL_LogSoftmaxOp : TFL_Op<"log_softmax", [ input - log(reduce_sum(exp(input), dim)) }]; - let arguments = (ins AnyTensor:$input); + let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$input); - let results = (outs AnyTensor:$output); + let results = (outs TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$output); let hasOptions = 1; } @@ -1727,6 +1848,9 @@ def MaxPoolOperandAndResultConstraints : PredOpTrait<"MaxPool2D operand and " TFL_TCresVTEtIsSameAsOp<0, 0>]>>; def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [ + TFL_OperandHasRank<0, 4>, + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, NoSideEffect, MaxPoolOperandAndResultConstraints, SameOperandsAndResultsScale, @@ -1741,7 +1865,7 @@ def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [ }]; let arguments = ( - ins AnyTensor:$input, + ins TFL_TensorOf<[F32, QUI8, QI8, QI16, TFL_Quint8]>:$input, TFL_PaddingAttr:$padding, I32Attr:$stride_w, I32Attr:$stride_h, @@ -1750,7 +1874,7 @@ def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [ TFL_AFAttr:$fused_activation_function ); - let results = (outs AnyTensor:$output); + let results = (outs TFL_TensorOf<[F32, QUI8, QI8, QI16, TFL_Quint8]>:$output); let hasOptions = 1; @@ -1782,7 +1906,11 @@ def TFL_MaximumOp : TFL_Op<"maximum", [ let hasOptions = 0; } -def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect, TFL_GpuTargetOp]> { +def TFL_MeanOp : TFL_Op<"mean", [ + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoSideEffect, + TFL_GpuTargetOp]> { let summary = "Mean operator"; let description = [{ @@ -1794,13 +1922,13 @@ def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect, TFL_GpuTargetOp]> { }]; let arguments = (ins - TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8, TFL_Uint8]>:$input, + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, UI8]>:$input, TFL_TensorOf<[I32, I64]>:$axis, BoolAttr:$keep_dims ); let results = (outs - TFL_TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$output); + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, UI8]>:$output); let hasOptions = 1; let customOption = "ReducerOptions"; @@ -1821,20 +1949,23 @@ def TFL_OneHotOp : TFL_Op<"one_hot", [NoSideEffect]> { let arguments = (ins TFL_TensorOf<[I32, I64]>:$indices, TFL_I32Tensor:$depth, - TFL_TensorOf<[F32, I32, I64, I1]>:$on_value, - TFL_TensorOf<[F32, I32, I64, I1]>:$off_value, + TFL_TensorOf<[F32, I32, I64, I1, I8, UI8]>:$on_value, + TFL_TensorOf<[F32, I32, I64, I1, I8, UI8]>:$off_value, I32Attr:$axis ); let results = (outs - TFL_TensorOf<[F32, I32, I64, I1]>:$output + TFL_TensorOf<[F32, I32, I64, I1, I8, UI8]>:$output ); let hasOptions = 1; } -def TFL_RoundOp: TFL_Op<"round", [NoSideEffect, SameOperandsAndResultType]> { +def TFL_RoundOp: TFL_Op<"round", [ + NoSideEffect, + SameOperandsAndResultShape, + SameOperandsAndResultType]> { let summary = "Round operator"; let description = [{ @@ -1842,16 +1973,23 @@ Rounds the values of a tensor to the nearest integer, element-wise. 
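The log_softmax op above documents its result as input - log(reduce_sum(exp(input), dim)). A hedged single-row reference sketch (not the TFLite kernel) that subtracts the row maximum first so exp() cannot overflow:

```
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Reference log-softmax over one 1-D row: y[i] = x[i] - log(sum_j exp(x[j])).
// Subtracting the maximum first is mathematically equivalent and keeps exp()
// in a safe range.
std::vector<float> LogSoftmax(const std::vector<float>& x) {
  float max_val = *std::max_element(x.begin(), x.end());
  double sum = 0.0;
  for (float v : x) sum += std::exp(v - max_val);
  float log_sum = max_val + static_cast<float>(std::log(sum));
  std::vector<float> y(x.size());
  for (size_t i = 0; i < x.size(); ++i) y[i] = x[i] - log_sum;
  return y;
}

int main() {
  std::vector<float> y = LogSoftmax({1.0f, 2.0f, 3.0f});
  for (float v : y) std::printf("%f ", v);  // approx -2.4076 -1.4076 -0.4076
  std::printf("\n");
  return 0;
}
```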
}]; let arguments = (ins - TFL_TensorOf<[F32]>:$x + TFL_FpTensor:$x ); let results = (outs - TFL_TensorOf<[F32]>:$y + TFL_FpTensor:$y ); } def TFL_SliceOp : TFL_Op<"slice", [ - NoSideEffect, SameOperandsAndResultsScale, TFL_GpuTargetOp]> { + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoSideEffect, + SameOperandsAndResultsScale, + TFL_OperandHasRankAtMost<0, 4>, + TFL_OperandHasRankAtMost<1, 1>, + TFL_OperandHasRankAtMost<2, 1>, + TFL_GpuTargetOp]> { let summary = "Return a slice from 'input'."; let description = [{ @@ -1869,13 +2007,13 @@ equivalent to setting: }]; let arguments = (ins - AnyTensor:$input, + TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8]>:$input, TFL_I32OrI64Tensor:$begin, TFL_I32OrI64Tensor:$size ); let results = (outs - AnyTensor:$output + TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8]>:$output ); let verifier = [{ return Verify(*this); }]; @@ -1883,7 +2021,11 @@ equivalent to setting: let hasCanonicalizer = 1; } -def TFL_SumOp: TFL_Op<"sum", [NoSideEffect]> { +def TFL_SumOp: TFL_Op<"sum", [ + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoSideEffect]> { + let summary = "Sum operator"; let description = [{ @@ -1891,19 +2033,22 @@ def TFL_SumOp: TFL_Op<"sum", [NoSideEffect]> { }]; let arguments = (ins - AnyTensor:$input, + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input, TFL_I32Tensor:$axes, BoolAttr:$keep_dims ); - let results = (outs AnyTensor); + let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output); let hasOptions = 1; let customOption = "ReducerOptions"; } def TFL_ReduceMinOp: TFL_Op<"reduce_min", [ - NoSideEffect, SameOperandsAndResultsScale]> { + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoSideEffect, + SameOperandsAndResultsScale]> { let summary = "Min-reduction operator"; let description = [{ @@ -1911,19 +2056,23 @@ def TFL_ReduceMinOp: TFL_Op<"reduce_min", [ }]; let arguments = (ins - AnyTensor:$input, + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input, TFL_I32Tensor:$axes, BoolAttr:$keep_dims ); - let results = (outs AnyTensor); + let results = (outs + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output); let hasOptions = 1; let customOption = "ReducerOptions"; } def TFL_ReduceMaxOp: TFL_Op<"reduce_max", [ - NoSideEffect, SameOperandsAndResultsScale]> { + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoSideEffect, + SameOperandsAndResultsScale]> { let summary = "Max-reduction operator"; let description = [{ @@ -1931,18 +2080,22 @@ def TFL_ReduceMaxOp: TFL_Op<"reduce_max", [ }]; let arguments = (ins - AnyTensor:$input, + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input, TFL_I32Tensor:$axes, BoolAttr:$keep_dims ); - let results = (outs AnyTensor); + let results = (outs + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output); let hasOptions = 1; let customOption = "ReducerOptions"; } -def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [NoSideEffect]> { +def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [ + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoSideEffect]> { let summary = "Prod-reduction operator"; let description = [{ @@ -1950,12 +2103,13 @@ def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [NoSideEffect]> { }]; let arguments = (ins - TFL_TensorOf<[F32, I8, I32, I64]>:$input, + 
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input, TFL_I32Tensor:$axes, BoolAttr:$keep_dims ); - let results = (outs AnyTensor); + let results = (outs + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output); let hasOptions = 1; let customOption = "ReducerOptions"; @@ -1986,12 +2140,13 @@ def TFL_MinimumOp : TFL_Op<"minimum", [ let hasOptions = 0; } -def TFL_MulOp : TFL_Op<"mul", [ResultsBroadcastableShape, - NoSideEffect, - Commutative, - BinaryOpSameElementTypeConstraint, - TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>, - TFL_GpuTargetOp]> { +def TFL_MulOp : TFL_Op<"mul", [ + ResultsBroadcastableShape, + NoSideEffect, + Commutative, + BinaryOpSameElementTypeConstraint, + TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>, + TFL_GpuTargetOp]> { let summary = "Multiplication operator"; let description = [{ @@ -2032,7 +2187,11 @@ def TFL_NegOp: TFL_Op<"neg", [NoSideEffect, SameOperandsAndResultType]> { let hasFolder = 1; } -def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> { +def TFL_PackOp : TFL_Op<"pack", [ + PredOpTrait<"values and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoSideEffect, + SameOperandsAndResultsScale]> { let summary = "Packs a list of tensors along a dimension into one tensor"; let description = [{ @@ -2063,14 +2222,14 @@ def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> { }]; let arguments = (ins - TFL_VariadicTensorOf<[F32, I8, I16, I32, I64, QI8, QUI8, QI16]>:$values, + TFL_VariadicTensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QUI8, QI16, TFL_Quint8]>:$values, - I32Attr:$values_count, + Confined:$values_count, I32Attr:$axis ); let results = (outs - TFL_TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8, QI16]>:$output + TFL_TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QUI8, QI16, TFL_Quint8]>:$output ); let verifier = [{ return Verify(*this); }]; @@ -2081,8 +2240,11 @@ def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> { } def TFL_PadOp : TFL_Op<"pad", [ + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, NoSideEffect, SameOperandsAndResultsScale, + TFL_OperandHasRankAtMost<0, 4>, TFL_OperandHasRank<1, 2>, TFL_OperandRankEquals1DimOfOperand<0, 1>, TFL_GpuTargetOp]> { @@ -2113,22 +2275,25 @@ def TFL_PadOp : TFL_Op<"pad", [ ``` }]; - let arguments = (ins TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input, + let arguments = (ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input, TFL_I32OrI64Tensor:$padding); - let results = (outs TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$output); + let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output); let hasOptions = 1; } def TFL_PadV2Op : TFL_Op<"padv2", [ + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, NoSideEffect, SameOperandsAndResultsScale, + TFL_OperandHasRankAtMost<0, 4>, TFL_OperandHasRank<1, 2>, TFL_OperandHasRank<2, 0>, TFL_OperandRankEquals1DimOfOperand<0, 1>, PredOpTrait<"input and constant value operands must have same element type", - TCopVTEtAreSameAt<[0, 2]>>]> { + TFL_TCopVTEtAreSameAt<0, 2>>]> { let summary = "Padding operator v2"; let description = [{ @@ -2159,11 +2324,11 @@ def TFL_PadV2Op : TFL_Op<"padv2", [ }]; let arguments = ( - ins TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input, + ins TFL_TensorOf<[F32, I32, I64, UI8, QI8, QUI8, TFL_Quint8]>:$input, TFL_I32OrI64Tensor:$padding, - TFL_TensorOf<[F32, I8, I32, I64]>:$constant_values); 
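pad and padv2 above both require a [input_rank, 2] paddings operand (TFL_OperandHasRank<1, 2> together with TFL_OperandRankEquals1DimOfOperand<0, 1>). A small sketch of the resulting output shape, assuming the usual convention that row i holds the before/after padding amounts for dimension i; padv2's constant_values only supplies the fill value and does not affect the shape:

```
#include <cstdio>
#include <utility>
#include <vector>

// Output shape of a pad/padv2-style op: each dimension grows by the "before"
// and "after" amounts in row i of the [rank, 2] paddings tensor.
std::vector<int> PaddedShape(const std::vector<int>& input_shape,
                             const std::vector<std::pair<int, int>>& paddings) {
  std::vector<int> out(input_shape.size());
  for (size_t i = 0; i < input_shape.size(); ++i) {
    out[i] = paddings[i].first + input_shape[i] + paddings[i].second;
  }
  return out;
}

int main() {
  // A 2x3 input padded by [[1, 1], [2, 2]] becomes 4x7.
  std::vector<int> out = PaddedShape({2, 3}, {{1, 1}, {2, 2}});
  std::printf("%d x %d\n", out[0], out[1]);  // 4 x 7
  return 0;
}
```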
+ TFL_TensorOf<[F32, I32, I64, UI8, QI8, QUI8, TFL_Quint8]>:$constant_values); - let results = (outs TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$output); + let results = (outs TFL_TensorOf<[F32, I32, I64, UI8, QI8, QUI8, TFL_Quint8]>:$output); let hasOptions = 1; } @@ -2191,26 +2356,29 @@ def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape, let builders = [TFL_BroadcastableBinaryBuilder]; } -def TFL_PReluOp : TFL_Op<"prelu", [NoSideEffect, - TFL_GpuTargetOp, - SameOperandsAndResultsScale]> { +def TFL_PReluOp : TFL_Op<"prelu", [ + NoSideEffect, + ResultsBroadcastableShape, + TFL_GpuTargetOp, + TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>, + BinaryOpSameElementTypeConstraint, + PredOpTrait<"input and output must have the same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>]> { let summary = "Parameterized Relu operator"; let description = [{ Parameterized Relu operator x -> x >= 0 ? x : (alpha * x) where alpha is a trainable tensor. - alpha should have one less rank than the input as it doesn't have the batch - dimension, and the other dimensions either should be the same size as input - or size 1, where it is broadcasted in the second case. + input and alpha should be the same size as input or be broadcastable. }]; let arguments = ( - ins TFL_TensorOf<[F32, QUI8]>:$input, - TFL_TensorOf<[F32, QUI8]>:$alpha + ins TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$input, + TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$alpha ); - let results = (outs TFL_TensorOf<[F32, QUI8]>:$output); + let results = (outs TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$output); let verifier = [{ return Verify(*this); }]; } @@ -2228,10 +2396,13 @@ def TFL_RankOp: TFL_Op<"rank", [NoSideEffect]> { let hasFolder = 1; } -def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect, - SameOperandsAndResultShape, - SameOperandsAndResultsScale, - TFL_GpuTargetOp]> { +def TFL_ReluOp: TFL_Op<"relu", [ + PredOpTrait<"x and y must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoSideEffect, + SameOperandsAndResultShape, + SameOperandsAndResultsScale, + TFL_GpuTargetOp]> { let summary = "Relu operator"; let description = [{ @@ -2239,9 +2410,9 @@ def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect, x -> max(0, x) }]; - let arguments = (ins TFL_TensorOf<[F32, QUI8, I8]>:$x); + let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8]>:$x); - let results = (outs TFL_TensorOf<[F32, QUI8, I8]>:$y); + let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$y); // This builder doesn't work with quantized type, so it can only be used by // non-quantization tablegen patterns. Currently, it is used by the @@ -2255,10 +2426,13 @@ def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect, ]; } -def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect, - SameOperandsAndResultShape, - SameOperandsAndResultsScale, - TFL_GpuTargetOp]> { +def TFL_Relu6Op: TFL_Op<"relu6", [ + PredOpTrait<"x and y must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoSideEffect, + SameOperandsAndResultShape, + SameOperandsAndResultsScale, + TFL_GpuTargetOp]> { let summary = "Relu6 operator"; let description = [{ @@ -2266,9 +2440,9 @@ def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect, x -> max(0, min(6, x)) }]; - let arguments = (ins TFL_TensorOf<[F32, QUI8, I8]>:$x); + let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8]>:$x); - let results = (outs TFL_TensorOf<[F32, QUI8, I8]>:$y); + let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$y); // This builder doesn't work with quantized type, so it can only be used by // non-quantization tablegen patterns. 
Currently, it is used by the @@ -2282,9 +2456,12 @@ def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect, ]; } -def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [NoSideEffect, - SameOperandsAndResultShape, - SameOperandsAndResultsScale]> { +def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [ + PredOpTrait<"x and y must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoSideEffect, + SameOperandsAndResultShape, + SameOperandsAndResultsScale]> { let summary = "Relu1 operator"; let description = [{ @@ -2292,9 +2469,9 @@ def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [NoSideEffect, x -> max(-1, min(1, x)) }]; - let arguments = (ins TFL_TensorOf<[F32, QUI8, I8]>:$x); + let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8]>:$x); - let results = (outs TFL_TensorOf<[F32, QUI8, I8]>:$y); + let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$y); // This builder doesn't work with quantized type, so it can only be used by // non-quantization tablegen patterns. Currently, it is used by the @@ -2326,7 +2503,11 @@ def TFL_ReshapeOp: TFL_Op<"reshape", [ let hasFolder = 1; } -def TFL_ReverseSequenceOp : TFL_Op<"reverse_sequence", [NoSideEffect]> { +def TFL_ReverseSequenceOp : TFL_Op<"reverse_sequence", [ + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoSideEffect, + TFL_OperandHasRank<1, 1>]> { let summary = "Reverses variable length slices."; let description = [{ @@ -2343,15 +2524,15 @@ slice `i`, with the first `seq_lengths[i]` slices along dimension }]; let arguments = (ins - TFL_TensorOf<[F32, I16, I32, I64, TFL_Uint8]>:$input, + TFL_TensorOf<[F32, I32, I64, QI16, QUI8, TFL_Quint8]>:$input, TFL_I32OrI64Tensor:$seq_lengths, - I32Attr:$seq_dim, - I32Attr:$batch_dim + Confined:$seq_dim, + Confined:$batch_dim ); let results = (outs - TFL_TensorOf<[F32, I16, I32, I64, TFL_Uint8]>:$output + TFL_TensorOf<[F32, I32, I64, QI16, QUI8, TFL_Quint8]>:$output ); let hasOptions = 1; @@ -2359,6 +2540,7 @@ slice `i`, with the first `seq_lengths[i]` slices along dimension def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect, SameOperandsAndResultType, + SameOperandsAndResultShape, NoQuantizableResult, TFL_GpuTargetOp]> { let summary = "Reciprocal of square root operator"; @@ -2367,9 +2549,9 @@ def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect, Computes element-wise reverse square root of input }]; - let arguments = (ins AnyTensor:$x); + let arguments = (ins TFL_FpTensor:$x); - let results = (outs AnyTensor:$y); + let results = (outs TFL_FpTensor:$y); let hasFolder = 1; } @@ -2383,7 +2565,7 @@ def TFL_ShapeOp: TFL_Op<"shape", [NoSideEffect]> { let arguments = (ins AnyTensor:$input); - let results = (outs AnyTensor:$output); + let results = (outs TFL_TensorOf<[I32, I64]>:$output); DerivedTypeAttr out_type = DerivedTypeAttr<[{ return getResult().getType().cast().getElementType(); @@ -2392,9 +2574,11 @@ def TFL_ShapeOp: TFL_Op<"shape", [NoSideEffect]> { let hasOptions = 1; } -// TODO(jpienaar): Flesh this out. 
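relu, relu6 and relu_n1_to_1 above differ only in their clamp interval, and leaky_relu earlier in this file scales the negative side by alpha. A float reference sketch of all four (the quantized paths are not shown):

```
#include <algorithm>
#include <cstdio>

// Float reference forms of the activation ops defined above.
float Relu(float x)  { return std::max(0.0f, x); }                  // max(0, x)
float Relu6(float x) { return std::max(0.0f, std::min(6.0f, x)); }  // max(0, min(6, x))
float Relu1(float x) { return std::max(-1.0f, std::min(1.0f, x)); } // relu_n1_to_1
float LeakyRelu(float x, float alpha) { return x >= 0.0f ? x : alpha * x; }

int main() {
  std::printf("%g %g %g %g\n", Relu(-2.0f), Relu6(8.0f), Relu1(-3.0f),
              LeakyRelu(-2.0f, 0.1f));  // 0 6 -1 -0.2
  return 0;
}
```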
-def TFL_RangeOp: TFL_Op<"range", [NoSideEffect, TFL_OperandHasRank<0, 0>, - TFL_OperandHasRank<1, 0>, TFL_OperandHasRank<2, 0>, +def TFL_RangeOp: TFL_Op<"range", [ + NoSideEffect, + TFL_OperandHasRank<0, 0>, + TFL_OperandHasRank<1, 0>, + TFL_OperandHasRank<2, 0>, PredOpTrait<"operands and output must have same element type", And<[TCresVTEtIsSameAsOp<0, 0>, TCresVTEtIsSameAsOp<0, 1>, TCresVTEtIsSameAsOp<0, 2>]>>]> { @@ -2406,17 +2590,20 @@ def TFL_RangeOp: TFL_Op<"range", [NoSideEffect, TFL_OperandHasRank<0, 0>, }]; let arguments = (ins - AnyTensor:$start, - AnyTensor:$limit, - AnyTensor:$delta); + TFL_TensorOf<[I32, F32]>:$start, + TFL_TensorOf<[I32, F32]>:$limit, + TFL_TensorOf<[I32, F32]>:$delta); - let results = (outs AnyTensor:$result); + let results = (outs TFL_TensorOf<[I32, F32]>:$result); let hasFolder = 1; } -def TFL_ReverseV2Op: TFL_Op<"reverse_v2", - [NoSideEffect, TFL_OperandHasRank<1,1>]> { +def TFL_ReverseV2Op: TFL_Op<"reverse_v2", [ + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoSideEffect, + TFL_OperandHasRank<1, 1>]> { let summary = "ReverseV2 Operator"; let description = [{ @@ -2438,21 +2625,21 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2", let arguments = ( ins - TFL_TensorOf<[F32, I16, I32, I64, TFL_Uint8, I1]>:$input, - TFL_TensorOf<[I32, I64]>:$axis + TFL_TensorOf<[F32, UI8, I16, I32, I64, QI16, QUI8, TFL_Quint8, I1]>:$input, + TFL_I32Tensor:$axis ); let results = (outs - TFL_TensorOf<[F32, I16, I32, I64, TFL_Uint8, I1]>:$output - ); + TFL_TensorOf<[F32, UI8, I16, I32, I64, QI16, QUI8, TFL_Quint8, I1]>:$output); } // Select has many instances in TF models where one or more of its operands // are unranked. Therefore, we skip adding shape constraints here. -def TFL_SelectOp : TFL_Op<"select", [NoSideEffect, +def TFL_SelectOp : TFL_Op<"select", [ + NoSideEffect, PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<1, 2>>, PredOpTrait<"operands and result have same element type", - TCresVTEtIsSameAsOp<0, 1>>]> { + TFL_TCresVTEtIsSameAsOp<0, 1>>]> { let summary = "Select operator"; let description = [{ @@ -2465,9 +2652,11 @@ def TFL_SelectOp : TFL_Op<"select", [NoSideEffect, let arguments = (ins TFL_BoolTensor:$condition, - TFL_TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$x, - TFL_TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$y); - let results = (outs AnyTensor:$output); + TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$x, + TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$y); + + let results = (outs + TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$output); // TODO(jpienaar): autogenerate this. 
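range above constrains start, limit and delta to scalars (the three rank-0 operand traits) of a common element type and produces a 1-D result. A sketch of the element count and contents, assuming the half-open [start, limit) convention of tf.range:

```
#include <cmath>
#include <cstdio>
#include <vector>

// Contents of range(start, limit, delta), assuming the half-open
// [start, limit) convention: size = ceil((limit - start) / delta).
std::vector<float> Range(float start, float limit, float delta) {
  int size = static_cast<int>(std::ceil((limit - start) / delta));
  if (size < 0) size = 0;
  std::vector<float> out(size);
  for (int i = 0; i < size; ++i) out[i] = start + i * delta;
  return out;
}

int main() {
  std::vector<float> r = Range(3.0f, 18.0f, 3.0f);  // 3 6 9 12 15
  std::printf("size=%zu last=%g\n", r.size(), r.back());
  return 0;
}
```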
let builders = [OpBuilder<"OpBuilder &builder, OperationState &result, " @@ -2481,7 +2670,12 @@ def TFL_SelectOp : TFL_Op<"select", [NoSideEffect, let hasOptions = 1; } -def TFL_SelectV2Op : TFL_Op<"select_v2", [NoSideEffect]> { +def TFL_SelectV2Op : TFL_Op<"select_v2", [ + NoSideEffect, + TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<1, 2, 4>, + PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<1, 2>>, + PredOpTrait<"operands and result have same element type", + TFL_TCresVTEtIsSameAsOp<0, 1>>]> { let summary = "SelectV2 operator"; let description = [{ @@ -2494,9 +2688,11 @@ def TFL_SelectV2Op : TFL_Op<"select_v2", [NoSideEffect]> { let arguments = (ins TFL_BoolTensor:$condition, - TFL_TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$x, - TFL_TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$y); - let results = (outs AnyTensor:$output); + TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$x, + TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$y); + + let results = (outs + TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$output); let builders = [OpBuilder<"OpBuilder &builder, OperationState &result, " "Value cond, Value x, Value y", @@ -2525,9 +2721,11 @@ def TFL_SinOp: TFL_Op<"sin", [ let hasFolder = 1; } -// TODO(b/130643170): Adds some constraint for the input/output element types. def TFL_SoftmaxOp : TFL_Op<"softmax", [ NoSideEffect, + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + TFL_OperandHasRankRange<0, 1, 4>, SameOperandsAndResultShape, // zero_point = 0 // scale = 1. / (max_value + 1) @@ -2543,11 +2741,11 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [ }]; let arguments = ( - ins AnyTensor:$input, + ins TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$input, F32Attr:$beta ); - let results = (outs AnyTensor:$output); + let results = (outs TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$output); let hasOptions = 1; } @@ -2590,7 +2788,11 @@ def TFL_SquareOp: TFL_Op<"square", [ let hasFolder = 1; } -def TFL_SubOp : TFL_Op<"sub", [ResultsBroadcastableShape, NoSideEffect]> { +def TFL_SubOp : TFL_Op<"sub", [ + ResultsBroadcastableShape, + BinaryOpSameElementTypeConstraint, + TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 5>, + NoSideEffect]> { let summary = "Subtraction operator"; let description = [{ @@ -2598,11 +2800,11 @@ def TFL_SubOp : TFL_Op<"sub", [ResultsBroadcastableShape, NoSideEffect]> { }]; let arguments = ( - ins AnyTensor:$lhs, - AnyTensor:$rhs, + ins TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$lhs, + TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$rhs, TFL_AFAttr:$fused_activation_function); - let results = (outs AnyTensor:$output); + let results = (outs TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$output); let hasFolder = 1; @@ -2618,6 +2820,8 @@ def TFL_SubOp : TFL_Op<"sub", [ResultsBroadcastableShape, NoSideEffect]> { // TODO(jpienaar): Expand the kernel implementation to support all types besides // I32 and F32. 
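sub and select_v2 above (and squared_difference just below) pick up TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape constraints. The sketch below illustrates the NumPy-style rule such broadcastable shapes follow — dimensions are aligned from the right and each pair must be equal or 1. It illustrates the rule only; it is not the TableGen predicate itself:

```
#include <algorithm>
#include <cstdio>
#include <vector>

// NumPy-style broadcast of two shapes: align from the right; each pair of
// dimensions must be equal, or one of them must be 1. Returns false if the
// shapes are not broadcastable.
bool BroadcastShapes(const std::vector<int>& a, const std::vector<int>& b,
                     std::vector<int>* out) {
  size_t rank = std::max(a.size(), b.size());
  out->assign(rank, 1);
  for (size_t i = 0; i < rank; ++i) {
    int da = i < a.size() ? a[a.size() - 1 - i] : 1;
    int db = i < b.size() ? b[b.size() - 1 - i] : 1;
    if (da != db && da != 1 && db != 1) return false;
    (*out)[rank - 1 - i] = std::max(da, db);
  }
  return true;
}

int main() {
  std::vector<int> out;
  // [8, 1, 6] and [7, 1] broadcast to [8, 7, 6].
  if (BroadcastShapes({8, 1, 6}, {7, 1}, &out))
    std::printf("%d %d %d\n", out[0], out[1], out[2]);
  return 0;
}
```

The last template parameter of the trait additionally bounds how large the broadcast is allowed to be for the TFLite kernels; the sketch does not model that bound.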
def TFL_SquaredDifferenceOp : TFL_Op<"squared_difference", [ + TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>, + SameOperandsAndResultElementType, ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult, @@ -2629,10 +2833,10 @@ def TFL_SquaredDifferenceOp : TFL_Op<"squared_difference", [ }]; let arguments = ( - ins AnyTensor:$lhs, - AnyTensor:$rhs); + ins TFL_TensorOf<[F32, I32]>:$lhs, + TFL_TensorOf<[F32, I32]>:$rhs); - let results = (outs AnyTensor:$output); + let results = (outs TFL_TensorOf<[F32, I32]>:$output); let builders = [TFL_BroadcastableBinaryBuilder]; @@ -2644,6 +2848,8 @@ def TFL_SquaredDifferenceOp : TFL_Op<"squared_difference", [ def TFL_TanhOp: TFL_Op<"tanh", [ NoSideEffect, SameOperandsAndResultShape, + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, // central_value = min_value / 2 + (max_value - 1) / 2 + 1 // zero_point = central_value // scale = 1. / (central_value - min_value) @@ -2656,9 +2862,9 @@ def TFL_TanhOp: TFL_Op<"tanh", [ Computes element-wise Hyperbolic tangent of input }]; - let arguments = (ins TFL_TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$x); + let arguments = (ins TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$input); - let results = (outs TFL_TensorOf<[F32, I16, I8, QI8, QUI8, QI16, QUI16, TFL_Uint8]>:$y); + let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$output); // This builder doesn't work with quantized type, so it can only be used by // non-quantization tablegen patterns. Currently, it is used by the @@ -2672,9 +2878,11 @@ def TFL_TanhOp: TFL_Op<"tanh", [ ]; } -def TFL_TileOp: TFL_Op<"tile", [NoSideEffect, SameOperandsAndResultsScale, - PredOpTrait<"resultant element type needs to match first operand type", - TFL_TCresVTEtIsSameAsOp<0,0>>]> { +def TFL_TileOp: TFL_Op<"tile", [ + NoSideEffect, + SameOperandsAndResultsScale, + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>]> { let summary = "Tile operator."; let description = [{ Constructs a tensor by tiling a given tensor. @@ -2687,11 +2895,11 @@ def TFL_TileOp: TFL_Op<"tile", [NoSideEffect, SameOperandsAndResultsScale, }]; let arguments = (ins - TFL_TensorOf<[F32, I1, I32, I64, TFL_Uint8, QUI8, TFL_Str]>:$input, + TFL_TensorOf<[F32, I1, I32, I64, UI8, QUI8, TFL_Str]>:$input, TFL_I32OrI64Tensor:$multiples); let results = (outs - TFL_TensorOf<[F32, I1, I32, I64, TFL_Uint8, QUI8, TFL_Str]>:$output); + TFL_TensorOf<[F32, I1, I32, I64, UI8, QUI8, TFL_Str]>:$output); let hasOptions = 0; } @@ -2699,9 +2907,13 @@ def TFL_TileOp: TFL_Op<"tile", [NoSideEffect, SameOperandsAndResultsScale, // TODO(jpienaar): Maybe make it accept any single element tensor as `k`. // TODO(jpienaar): Check that input has one or more dimensions. 
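tile above builds its output by repeating the input along each axis according to a multiples vector, so output_dim[i] = input_dim[i] * multiples[i]. A 1-D reference sketch of the index mapping:

```
#include <cstdio>
#include <vector>

// 1-D reference tile: the output has input.size() * multiple elements and
// output[j] = input[j % input.size()]. Higher ranks apply the same rule
// per dimension.
std::vector<int> Tile1D(const std::vector<int>& input, int multiple) {
  std::vector<int> out(input.size() * multiple);
  for (size_t j = 0; j < out.size(); ++j) out[j] = input[j % input.size()];
  return out;
}

int main() {
  std::vector<int> out = Tile1D({1, 2, 3}, 2);  // 1 2 3 1 2 3
  for (int v : out) std::printf("%d ", v);
  std::printf("\n");
  return 0;
}
```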
// TODO(jpienaar): Check that k is less or equal the internal dimension -def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>, +def TFL_TopKV2Op: TFL_Op<"topk_v2", [ + NoSideEffect, + TFL_OperandHasRankAtLeast<0, 1>, + TFL_OperandHasRank<1, 0>, PredOpTrait<"result and input element type match", - TCresVTEtIsSameAsOp<0,0>>, SameOperandsAndResultsScale]> { + TFL_TCresVTEtIsSameAsOp<0,0>>, + SameOperandsAndResultsScale]> { let summary = "TopK operator"; let description = [{ @@ -2711,11 +2923,11 @@ def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>, }]; let arguments = (ins - TFL_TensorOf<[F32, I8, I32, I64, TFL_Uint8, QI8, QUI8]>:$input, + TFL_TensorOf<[F32, I8, I32, I64, UI8, QI8, QUI8]>:$input, TFL_I32Tensor:$k); let results = (outs - TFL_TensorOf<[F32, I8, I32, I64, TFL_Uint8, QI8, QUI8]>:$values, + TFL_TensorOf<[F32, I8, I32, I64, UI8, QI8, QUI8]>:$values, TFL_I32Tensor:$indices); let builders = [OpBuilder<"OpBuilder &builder, OperationState &result, " @@ -2725,29 +2937,27 @@ def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>, let hasOptions = 1; } -def TFL_TransposeOp : TFL_Op<"transpose", - [NoSideEffect, - TFL_OperandHasRank<1,1>, - // TODO(jpienaar): these are only true dynamically, change so that it works - // with unknowns. - // TFL_OperandRankEquals1DimOfOperand<0, 1>, - PredOpTrait<"input and output must have same element type", - TCresVTEtIsSameAsOp<0, 0>>, - SameOperandsAndResultsScale, - TFL_GpuTargetOp]> { +def TFL_TransposeOp : TFL_Op<"transpose", [ + NoSideEffect, + TFL_OperandHasRankAtMost<0, 5>, + TFL_OperandHasRank<1, 1>, + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + SameOperandsAndResultsScale, + TFL_GpuTargetOp]> { let summary = "Transpose operator"; let description = [{ Returns the Transpose of x }]; - let arguments = ( - ins AnyTensor:$x, + let arguments = (ins + TFL_TensorOf<[I32, F32, I8, UI8, QI8, QUI8, TFL_Quint8, I1, I64]>:$input, TFL_TensorOf<[I32]>:$perm ); let results = (outs - AnyTensor:$y + TFL_TensorOf<[I32, F32, I8, UI8, QI8, QUI8, TFL_Quint8, I1, I64]>:$output ); let verifier = [{ return Verify(*this); }]; @@ -2755,7 +2965,10 @@ def TFL_TransposeOp : TFL_Op<"transpose", let hasFolder = 1; } -def TFL_UnpackOp : TFL_Op<"unpack", [NoSideEffect, SameOperandsAndResultsScale]> { +def TFL_UnpackOp : TFL_Op<"unpack", [ + NoSideEffect, + SameOperandsAndResultElementType, + SameOperandsAndResultsScale]> { let summary = "Unpacks a tensor along a dimension into multiple tensors"; let description = [{ @@ -2776,14 +2989,14 @@ def TFL_UnpackOp : TFL_Op<"unpack", [NoSideEffect, SameOperandsAndResultsScale]> }]; let arguments = (ins - TFL_TensorOf<[F32, I1, I8, I32, QI8, QUI8]>:$input, + TFL_TensorOf<[F32, I1, I8, UI8, I32, QI8, QUI8, I16, QI16]>:$input, I32Attr:$num, I32Attr:$axis ); let results = (outs - TFL_VariadicTensorOf<[F32, I1, I8, I32, QI8, QUI8]>:$outputs + TFL_VariadicTensorOf<[F32, I1, I8, UI8, I32, QI8, QUI8, I16, QI16]>:$outputs ); let verifier = [{ return Verify(*this); }]; @@ -2791,16 +3004,19 @@ def TFL_UnpackOp : TFL_Op<"unpack", [NoSideEffect, SameOperandsAndResultsScale]> let hasOptions = 1; } -def TFL_ZerosLikeOp: TFL_Op<"zeros_like", [NoSideEffect]> { +def TFL_ZerosLikeOp: TFL_Op<"zeros_like", [ + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoSideEffect]> { let summary = "ZerosLike operator"; let description = [{ Returns a tensor of zeros with the same shape and type as the 
input tensor. }]; - let arguments = (ins AnyTensor:$input); + let arguments = (ins TFL_TensorOf<[I64, I32, F32]>:$input); - let results = (outs AnyTensor:$output); + let results = (outs TFL_TensorOf<[I64, I32, F32]>:$output); let hasOptions = 1; } @@ -2834,8 +3050,9 @@ def TFL_BatchToSpaceNdOp: TFL_Op<"batch_to_space_nd", [ def TFL_SpaceToBatchNdOp: TFL_Op<"space_to_batch_nd", [ NoSideEffect, SameOperandsAndResultsScale, + TFL_OperandHasRankRange<0, 3, 4>, PredOpTrait<"input and output must have same element type", - TCresVTEtIsSameAsOp<0, 0>> + TFL_TCresVTEtIsSameAsOp<0, 0>> ]> { let summary = "SpaceToBatchNd operator"; @@ -2844,13 +3061,13 @@ def TFL_SpaceToBatchNdOp: TFL_Op<"space_to_batch_nd", [ }]; let arguments = (ins - TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input, - TFL_TensorOf<[I32]>:$block_shape, - TFL_TensorOf<[I32]>:$paddings + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input, + TFL_I32Tensor:$block_shape, + TFL_I32Tensor:$paddings ); let results = (outs - TFL_TensorOf<[F32, I16, I32, I64, QI8, QUI8]>:$output + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output ); } @@ -2858,7 +3075,8 @@ def TFL_SpaceToDepthOp: TFL_Op<"space_to_depth", [ NoSideEffect, SameOperandsAndResultsScale, PredOpTrait<"input and output must have same element type", - TCresVTEtIsSameAsOp<0, 0>>, + TFL_TCresVTEtIsSameAsOp<0, 0>>, + TFL_OperandHasRankAtMost<0, 4>, TFL_GpuTargetOp ]> { let summary = "SpaceToDepth operator"; @@ -2871,12 +3089,12 @@ def TFL_SpaceToDepthOp: TFL_Op<"space_to_depth", [ }]; let arguments = (ins - TFL_TensorOf<[F32, I8, I32, I64, TFL_Uint8, QUI8]>:$input, - I32Attr:$block_size + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input, + Confined:$block_size ); let results = (outs - TFL_TensorOf<[F32, I8, I32, I64, TFL_Uint8, QUI8]>:$output + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output ); let hasOptions = 1; @@ -2887,7 +3105,7 @@ def TFL_DepthToSpaceOp: TFL_Op<"depth_to_space", [ SameOperandsAndResultsScale, PredOpTrait<"input and output must have same element type", TFL_TCresVTEtIsSameAsOp<0, 0>>, - TFL_OperandHasRankLessThanOrEqualTo<0, 4> + TFL_OperandHasRankAtMost<0, 4> ]> { let summary = "DepthToSpace operator"; @@ -2901,12 +3119,12 @@ def TFL_DepthToSpaceOp: TFL_Op<"depth_to_space", [ }]; let arguments = (ins - TFL_TensorOf<[F32, I8, I32, I64, TFL_Quint8, TFL_Uint8, UI8, QI8, QUI8]>:$input, + TFL_TensorOf<[F32, I8, I32, I64, TFL_Quint8, UI8, QI8, QUI8]>:$input, Confined:$block_size ); let results = (outs - TFL_TensorOf<[F32, I8, I32, I64, TFL_Quint8, TFL_Uint8, UI8, QI8, QUI8]>:$output + TFL_TensorOf<[F32, I8, I32, I64, TFL_Quint8, UI8, QI8, QUI8]>:$output ); let hasOptions = 1; @@ -2926,12 +3144,12 @@ def TFL_SplitOp : TFL_Op<"split", [ let arguments = (ins TFL_TensorOf<[I32]>:$split_dim, - TFL_TensorOf<[F32, I16, I32, I64, QI8, QUI8, QI16]>:$value, + TFL_TensorOf<[F32, I16, I32, I64, I8, UI8, QI8, QUI8, QI16]>:$value, Confined:$num_splits ); let results = (outs - TFL_VariadicTensorOf<[F32, I16, I32, I64, QI8, QUI8, QI16]>:$outputs + TFL_VariadicTensorOf<[F32, I16, I32, I64, I8, UI8, QI8, QUI8, QI16]>:$outputs ); let verifier = [{ return Verify(*this); }]; @@ -2949,14 +3167,14 @@ def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect, SameOperandsAndResultsScale] }]; let arguments = (ins - TFL_TensorOf<[F32, I16, I32, I64, QI8, QUI8, QI16]>:$value, + TFL_TensorOf<[F32, I16, I32, I64, I8, UI8, QI8, QUI8, QI16]>:$value, TFL_1DTensorOf<[I32], [I32]>:$size_splits, TFL_0DTensorOf<[I32], [I32]>:$split_dim, Confined:$num_splits ); let 
results = (outs - TFL_VariadicTensorOf<[F32, I16, I32, I64, QI8, QUI8, QI16]>:$outputs + TFL_VariadicTensorOf<[F32, I16, I32, I64, I8, UI8, QI8, QUI8, QI16]>:$outputs ); let verifier = [{ return Verify(*this); }]; @@ -2965,7 +3183,12 @@ def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect, SameOperandsAndResultsScale] } def TFL_ResizeBilinearOp: TFL_Op<"resize_bilinear", [ - NoSideEffect, SameOperandsAndResultsScale]> { + NoSideEffect, + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + TFL_OperandHasRank<0, 4>, + TFL_OperandHasRank<1, 1>, + SameOperandsAndResultsScale]> { let summary = "ResizeBilinear Op"; let description = [{ @@ -2973,23 +3196,26 @@ def TFL_ResizeBilinearOp: TFL_Op<"resize_bilinear", [ }]; let arguments = (ins - // TODO(ycling): Support quantized types. - TFL_TensorOf<[F32, I32, QI8, QUI8]>:$input, - TFL_TensorOf<[I32]>:$size, + TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$input, + TFL_I32Tensor:$size, BoolAttr:$align_corners, DefaultValuedAttr:$half_pixel_centers ); let results = (outs - TFL_TensorOf<[F32, QI8, QUI8]>:$output + TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$output ); let hasOptions = 1; } -def TFL_ResizeNearestNeighborOp : TFL_Op<"resize_nearest_neighbor", - [NoSideEffect, - SameOperandsAndResultsScale]> { +def TFL_ResizeNearestNeighborOp : TFL_Op<"resize_nearest_neighbor", [ + NoSideEffect, + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + TFL_OperandHasRank<0, 4>, + TFL_OperandHasRank<1, 1>, + SameOperandsAndResultsScale]> { let summary = "ResizeNearestNeighbor Op"; let description = [{ @@ -2997,20 +3223,28 @@ def TFL_ResizeNearestNeighborOp : TFL_Op<"resize_nearest_neighbor", }]; let arguments = (ins - TFL_TensorOf<[F32, I8, TFL_Uint8, QUI8, QI8]>:$input, - TFL_TensorOf<[I32]>:$size, + TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8]>:$input, + TFL_I32Tensor:$size, BoolAttr:$align_corners, DefaultValuedAttr:$half_pixel_centers ); let results = (outs - TFL_TensorOf<[F32, I8, TFL_Uint8, QUI8, QI8]>:$output + TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8]>:$output ); let hasOptions = 1; } -def TFL_SparseToDenseOp : TFL_Op<"sparse_to_dense", [NoSideEffect]> { +def TFL_SparseToDenseOp : TFL_Op<"sparse_to_dense", [ + NoSideEffect, + PredOpTrait<"sparse_values and dense must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 2>>, + PredOpTrait<"default_value and dense must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 3>>, + TFL_OperandHasRankAtMost<0, 2>, + TFL_OperandHasRankAtMost<1, 1>, + TFL_OperandHasRankAtMost<2, 1>]> { let summary = "Converts a sparse representation into a dense tensor."; let description = [{ @@ -3038,21 +3272,24 @@ are checked during execution. 
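sparse_to_dense above gains predicates tying sparse_values (operand 2) and default_value (operand 3) to the element type of the dense result; the op fills a tensor of output_shape with default_value and then scatters sparse_values at sparse_indices. A 1-D sketch of that behavior:

```
#include <cstdio>
#include <vector>

// 1-D sparse_to_dense: start from default_value everywhere, then write
// sparse_values[i] at position sparse_indices[i].
std::vector<float> SparseToDense1D(const std::vector<int>& sparse_indices,
                                   int output_size,
                                   const std::vector<float>& sparse_values,
                                   float default_value) {
  std::vector<float> dense(output_size, default_value);
  for (size_t i = 0; i < sparse_indices.size(); ++i)
    dense[sparse_indices[i]] = sparse_values[i];
  return dense;
}

int main() {
  std::vector<float> d = SparseToDense1D({1, 3}, 5, {7.0f, 9.0f}, 0.0f);
  for (float v : d) std::printf("%g ", v);  // 0 7 0 9 0
  std::printf("\n");
  return 0;
}
```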
let arguments = (ins TFL_I32OrI64Tensor:$sparse_indices, TFL_I32OrI64Tensor:$output_shape, - TFL_TensorOf<[I32, I64, I8, TFL_Uint8, F32]>:$sparse_values, - TFL_TensorOf<[I32, I64, I8, TFL_Uint8, F32]>:$default_value + TFL_TensorOf<[I32, I64, I8, QI8, UI8, QUI8, TFL_Quint8, F32]>:$sparse_values, + TFL_TensorOf<[I32, I64, I8, QI8, UI8, QUI8, TFL_Quint8, F32]>:$default_value ); let results = (outs - TFL_TensorOf<[I32, I64, I8, TFL_Uint8, F32]>:$dense + TFL_TensorOf<[I32, I64, I8, QI8, UI8, QUI8, TFL_Quint8, F32]>:$dense ); } -def TFL_StridedSliceOp: TFL_Op<"strided_slice", - [ +def TFL_StridedSliceOp: TFL_Op<"strided_slice", [ NoSideEffect, PredOpTrait<"input and output must have same element type", TFL_TCresVTEtIsSameAsOp<0, 0>>, SameOperandsAndResultsScale, + TFL_OperandHasRankAtMost<0, 5>, + TFL_OperandHasRank<1, 1>, + TFL_OperandHasRank<2, 1>, + TFL_OperandHasRank<3, 1>, TFL_GpuTargetOp ]> { let summary = "StridedSlice Op"; @@ -3062,20 +3299,20 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", }]; let arguments = (ins - TFL_TensorOf<[F32, I32, I64, I8, QI8, QUI8, I1, TFL_Quint8, TFL_Uint8]>:$input, - TFL_TensorOf<[I32]>:$begin, - TFL_TensorOf<[I32]>:$end, - TFL_TensorOf<[I32]>:$strides, + TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8]>:$input, + TFL_I32Tensor:$begin, + TFL_I32Tensor:$end, + TFL_I32Tensor:$strides, I32Attr:$begin_mask, I32Attr:$end_mask, - I32Attr:$ellipsis_mask, - I32Attr:$new_axis_mask, + Confined]>:$ellipsis_mask, + Confined]>:$new_axis_mask, I32Attr:$shrink_axis_mask ); let results = (outs - TFL_TensorOf<[F32, I32, I64, I8, QI8, QUI8, I1, TFL_Quint8, TFL_Uint8]>:$output + TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8]>:$output ); let hasOptions = 1; @@ -3090,17 +3327,16 @@ def TFL_CastOp : TFL_Op<"cast", [ }]; let arguments = (ins - TFL_TensorOf<[F32, I1, I32, I64, TFL_Quint8, TFL_Uint8, Complex>]>:$input + TFL_TensorOf<[F32, I1, I32, I64, TFL_Quint8, UI8, Complex>]>:$input ); - let results = (outs TFL_TensorOf<[F32, I1, I32, I64, TFL_Quint8, TFL_Uint8, Complex>]>:$output); + let results = (outs TFL_TensorOf<[F32, I1, I32, I64, TFL_Quint8, UI8, Complex>]>:$output); // TFLite's cast op does not utilize CastOptions, instead derives types // from the TfLiteTensors. let hasOptions = 0; } - def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [ NoSideEffect, TFL_OperandHasRank<1, 2>, TFL_GpuTargetOp]> { let summary = "MirrorPad Operator. Pads a tensor with mirrored values."; @@ -3136,24 +3372,25 @@ def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [ let hasOptions = 1; } -def TFL_UniqueOp: TFL_Op<"unique", [NoSideEffect]> { +def TFL_UniqueOp: TFL_Op<"unique", [ + TFL_OperandHasRank<0, 1>, + NoSideEffect]> { let summary = "Unique Op."; let description = [{ - This operation returns a tensor `y` containing all of the unique elements of `x` -sorted in the same order that they occur in `x`. This operation also returns a -tensor `idx` the same size as `x` that contains the index of each value of `x` -in the unique output `y`. In other words: + This operation returns a tensor `output` containing all of the unique elements +of `input` sorted in the same order that they occur in `input`. This operation +also returns a tensor `idx` the same size as `x` that contains the index of each +value of `input` in the unique output `output`. In other words: }]; let arguments = (ins - // TODO: add uint8 support after quantize support. 
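The reworded unique description above promises two results: the distinct values in order of first occurrence, plus an idx tensor that maps every input element back into that output. A short reference sketch:

```
#include <cstdio>
#include <unordered_map>
#include <vector>

// Reference unique: `values` holds each distinct element in order of first
// occurrence, and idx[i] is the position of input[i] inside `values`.
void Unique(const std::vector<int>& input, std::vector<int>* values,
            std::vector<int>* idx) {
  std::unordered_map<int, int> first_seen;  // element -> position in `values`
  values->clear();
  idx->resize(input.size());
  for (size_t i = 0; i < input.size(); ++i) {
    auto it = first_seen.find(input[i]);
    if (it == first_seen.end()) {
      it = first_seen.emplace(input[i], static_cast<int>(values->size())).first;
      values->push_back(input[i]);
    }
    (*idx)[i] = it->second;
  }
}

int main() {
  std::vector<int> values, idx;
  Unique({1, 1, 2, 4, 4, 4, 7}, &values, &idx);
  // values: 1 2 4 7   idx: 0 0 1 2 2 2 3
  for (int v : values) std::printf("%d ", v);
  std::printf("| ");
  for (int v : idx) std::printf("%d ", v);
  std::printf("\n");
  return 0;
}
```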
- TFL_TensorOf<[I8, I16, I32, I64, F32]>:$input + TFL_TensorOf<[I8, QI8, UI8, QUI8, I16, QI16, I32, I64, F32]>:$input ); let results = (outs - TFL_TensorOf<[I8, I16, I32, I64, F32]>:$output, - TFL_TensorOf<[I32, I64]>:$idx + TFL_TensorOf<[I8, QI8, UI8, QUI8, I16, QI16, I32, I64, F32]>:$output, + TFL_I32OrI64Tensor:$idx ); DerivedTFLiteTypeAttr idx_out_type = DerivedTFLiteTypeAttr<[{ @@ -3224,7 +3461,7 @@ def TFL_QConstOp : Op:$output); let builders = [OpBuilder< "OpBuilder &, OperationState &state, TypeAttr qtype, Attribute value", @@ -3253,7 +3490,7 @@ def TFL_SparseQConstOp : Op:$output); let builders = [OpBuilder< "OpBuilder &, OperationState &state, TypeAttr qtype, " @@ -3269,7 +3506,9 @@ def TFL_SparseQConstOp : Op { + FirstAttrDerivedResultType, + SameOperandsAndResultShape, + NoQuantizableResult]> { let summary = "Quantize operator"; let description = [{ @@ -3278,16 +3517,18 @@ def TFL_QuantizeOp: TFL_Op<"quantize", [ }]; let arguments = ( - ins AnyTensor:$input, + ins TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$input, TensorTypeAttr:$qtype ); - let results = (outs AnyTensor:$output); + let results = (outs TFL_TensorOf<[QI8, QUI8, QI16, TFL_Quint8]>:$output); } -def TFL_DensifyOp: TFL_Op<"densify", [NoSideEffect, - SameOperandsAndResultType, - NoQuantizableResult]> { +def TFL_DensifyOp: TFL_Op<"densify", [ + NoSideEffect, + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoQuantizableResult]> { let summary = "Densify operator"; let description = [{ @@ -3390,6 +3631,19 @@ def TFL_LSTMOp : LstmOptionalPeepholeWeightConstraint, LstmProjectionWeightBiasConstraint, LstmResultConstraint, + TFL_OperandHasRank<2, 2>, // input_to_forget_weights + TFL_OperandHasRank<3, 2>, // input_to_cell_weights + TFL_OperandIsNoneOrHasRank<5, 2>, // recurrent_to_input_weights + TFL_OperandHasRank<6, 2>, // recurrent_to_forget_weights + TFL_OperandHasRank<7, 2>, // recurrent_to_cell_weights + TFL_OperandIsNoneOrHasRank<9, 1>, // cell_to_input_weights + TFL_OperandIsNoneOrHasRank<10, 1>, // cell_to_forget_weights + TFL_OperandIsNoneOrHasRank<11, 1>, // cell_to_output_weights + TFL_OperandHasRank<13, 1>, // forget_gate_bias + TFL_OperandHasRank<14, 1>, // cell_gate_bias + TFL_OperandHasRank<15, 1>, // output_gate_bias + TFL_OperandIsNoneOrHasRank<16, 2>, // projection_weights + TFL_OperandIsNoneOrHasRank<17, 1>, // projection_bias TFL_StatefulOp]> { let summary = "The full lstm operator"; @@ -3416,23 +3670,23 @@ Ba et al. 
'Layer Normalization' ins TFL_TensorOf<[F32, QI8]>:$input, // Weights - TFL_TensorOfOrNone<[F32, I8, QI8]>:$input_to_input_weights, - TFL_TensorOf<[F32, I8, QI8]>:$input_to_forget_weights, - TFL_TensorOf<[F32, I8, QI8]>:$input_to_cell_weights, - TFL_TensorOf<[F32, I8, QI8]>:$input_to_output_weights, + TFL_TensorOfOrNone<[F32, QI8]>:$input_to_input_weights, + TFL_TensorOf<[F32, QI8]>:$input_to_forget_weights, + TFL_TensorOf<[F32, QI8]>:$input_to_cell_weights, + TFL_TensorOf<[F32, QI8]>:$input_to_output_weights, // Recurrent weights - TFL_TensorOfOrNone<[F32, I8, QI8]>:$recurrent_to_input_weights, - TFL_TensorOf<[F32, I8, QI8]>:$recurrent_to_forget_weights, - TFL_TensorOf<[F32, I8, QI8]>:$recurrent_to_cell_weights, - TFL_TensorOf<[F32, I8, QI8]>:$recurrent_to_output_weights, + TFL_TensorOfOrNone<[F32, QI8]>:$recurrent_to_input_weights, + TFL_TensorOf<[F32, QI8]>:$recurrent_to_forget_weights, + TFL_TensorOf<[F32, QI8]>:$recurrent_to_cell_weights, + TFL_TensorOf<[F32, QI8]>:$recurrent_to_output_weights, // Cell weights - TFL_TensorOfOrNone<[F32, I8, QI16]>:$cell_to_input_weights, + TFL_TensorOfOrNone<[F32, QI8, QI16]>:$cell_to_input_weights, // Optional input - TFL_TensorOfOrNone<[F32, I8, QI16]>:$cell_to_forget_weights, + TFL_TensorOfOrNone<[F32, QI8, QI16]>:$cell_to_forget_weights, // Optional input - TFL_TensorOfOrNone<[F32, I8, QI16]>:$cell_to_output_weights, + TFL_TensorOfOrNone<[F32, QI8, QI16]>:$cell_to_output_weights, // Bias TFL_TensorOfOrNone<[F32, QI32]>:$input_gate_bias, @@ -3441,7 +3695,7 @@ Ba et al. 'Layer Normalization' TFL_TensorOf<[F32, QI32]>:$output_gate_bias, // Projection weight and bias - TFL_TensorOfOrNone<[F32, I8, QI8]>:$projection_weights, + TFL_TensorOfOrNone<[F32, QI8]>:$projection_weights, // Optional input TFL_TensorOfOrNone<[F32, QI32]>:$projection_bias, @@ -3457,8 +3711,8 @@ Ba et al. 'Layer Normalization' // Attributes TFL_AFAttr:$fused_activation_function, - DefaultValuedAttr:$cell_clip, - DefaultValuedAttr:$proj_clip, + Confined, [TFL_FloatNonNegative]>:$cell_clip, + Confined, [TFL_FloatNonNegative]>:$proj_clip, // Since this op is the FULL kernel only, constrain it. 
Confined< DefaultValuedAttr, @@ -3498,6 +3752,24 @@ def TFL_UnidirectionalSequenceLSTMOp : LstmOptionalPeepholeWeightConstraint, LstmProjectionWeightBiasConstraint, LstmResultConstraint, + TFL_OperandHasRankAtLeast<0, 2>, // input + TFL_OperandIsNoneOrHasRank<1, 2>, // input_to_input_weights + TFL_OperandHasRank<2, 2>, // input_to_forget_weights + TFL_OperandHasRank<3, 2>, // input_to_cell_weights + TFL_OperandHasRank<4, 2>, // input_to_output_weights + TFL_OperandIsNoneOrHasRank<5, 2>, // recurrent_to_input_weights + TFL_OperandHasRank<6, 2>, // recurrent_to_forget_weights + TFL_OperandHasRank<7, 2>, // recurrent_to_cell_weights + TFL_OperandHasRank<8, 2>, // recurrent_to_output_weights + TFL_OperandIsNoneOrHasRank<9, 1>, // cell_to_input_weights + TFL_OperandIsNoneOrHasRank<10, 1>, // cell_to_forget_weights + TFL_OperandIsNoneOrHasRank<11, 1>, // cell_to_output_weights + TFL_OperandIsNoneOrHasRank<12, 1>, // input_gate_bias + TFL_OperandHasRank<13, 1>, // forget_gate_bias + TFL_OperandHasRank<14, 1>, // cell_gate_bias + TFL_OperandHasRank<15, 1>, // output_gate_bias + TFL_OperandIsNoneOrHasRank<16, 2>, // projection_weights + TFL_OperandIsNoneOrHasRank<17, 2>, // projection_bias TFL_StatefulOp]> { let summary = "Unidirectional sequence lstm operator"; @@ -3513,35 +3785,35 @@ def TFL_UnidirectionalSequenceLSTMOp : }]; let arguments = ( - ins TFL_TensorOf<[F32, I8]>:$input, + ins TFL_FpTensor:$input, // Weights - TFL_TensorOfOrNone<[F32, I8]>:$input_to_input_weights, - TFL_TensorOf<[F32, I8]>:$input_to_forget_weights, - TFL_TensorOf<[F32, I8]>:$input_to_cell_weights, - TFL_TensorOf<[F32, I8]>:$input_to_output_weights, + TFL_TensorOfOrNone<[F32, QI8]>:$input_to_input_weights, + TFL_TensorOf<[F32, QI8]>:$input_to_forget_weights, + TFL_TensorOf<[F32, QI8]>:$input_to_cell_weights, + TFL_TensorOf<[F32, QI8]>:$input_to_output_weights, // Recurrent weights - TFL_TensorOfOrNone<[F32, I8]>:$recurrent_to_input_weights, - TFL_TensorOf<[F32, I8]>:$recurrent_to_forget_weights, - TFL_TensorOf<[F32, I8]>:$recurrent_to_cell_weights, - TFL_TensorOf<[F32, I8]>:$recurrent_to_output_weights, + TFL_TensorOfOrNone<[F32, QI8]>:$recurrent_to_input_weights, + TFL_TensorOf<[F32, QI8]>:$recurrent_to_forget_weights, + TFL_TensorOf<[F32, QI8]>:$recurrent_to_cell_weights, + TFL_TensorOf<[F32, QI8]>:$recurrent_to_output_weights, // Cell weights - TFL_TensorOfOrNone<[F32, I8]>:$cell_to_input_weights, + TFL_TensorOfOrNone<[F32, QI8]>:$cell_to_input_weights, // Optional input - TFL_TensorOfOrNone<[F32, I8]>:$cell_to_forget_weights, + TFL_TensorOfOrNone<[F32, QI8]>:$cell_to_forget_weights, // Optional input - TFL_TensorOfOrNone<[F32, I8]>:$cell_to_output_weights, + TFL_TensorOfOrNone<[F32, QI8]>:$cell_to_output_weights, // Bias TFL_TensorOfOrNone<[F32]>:$input_gate_bias, - TFL_TensorOf<[F32]>:$forget_gate_bias, - TFL_TensorOf<[F32]>:$cell_bias, - TFL_TensorOf<[F32]>:$output_gate_bias, + TFL_FpTensor:$forget_gate_bias, + TFL_FpTensor:$cell_bias, + TFL_FpTensor:$output_gate_bias, // Projection weight and bias - TFL_TensorOfOrNone<[F32, I8]>:$projection_weights, + TFL_TensorOfOrNone<[F32, QI8]>:$projection_weights, // Optional input TFL_TensorOfOrNone<[F32]>:$projection_bias, @@ -3550,19 +3822,19 @@ def TFL_UnidirectionalSequenceLSTMOp : TFL_StatefulTensor:$input_cell_state, // Layer norm coefficients - TFL_TensorOfOrNone<[F32, I8]>:$input_layer_norm_coefficients, - TFL_TensorOfOrNone<[F32, I8]>:$forget_layer_norm_coefficients, - TFL_TensorOfOrNone<[F32, I8]>:$cell_layer_norm_coefficients, - TFL_TensorOfOrNone<[F32, 
I8]>:$output_layer_norm_coefficients, + TFL_TensorOfOrNone<[F32, QI8]>:$input_layer_norm_coefficients, + TFL_TensorOfOrNone<[F32, QI8]>:$forget_layer_norm_coefficients, + TFL_TensorOfOrNone<[F32, QI8]>:$cell_layer_norm_coefficients, + TFL_TensorOfOrNone<[F32, QI8]>:$output_layer_norm_coefficients, // Attributes TFL_AFAttr:$fused_activation_function, - DefaultValuedAttr:$cell_clip, - DefaultValuedAttr:$proj_clip, + Confined, [TFL_FloatNonNegative]>:$cell_clip, + Confined, [TFL_FloatNonNegative]>:$proj_clip, BoolAttr:$time_major ); - let results = (outs AnyTensor:$output); + let results = (outs TFL_TensorOf<[F32, QI8]>:$output); let hasOptions = 1; @@ -3759,15 +4031,14 @@ def TFL_BidirectionalSequenceLSTMOp : }]; } -def RnnResultConstraint : PredOpTrait< - "the input and result tensor elemental types must be same", - TCresVTEtIsSameAsOp<0, 0>>; - // UnidirectionalSequenceRNN op. -def TFL_UnidirectionalSequenceRNNOp : - TFL_Op<"unidirectional_sequence_rnn", - [RnnResultConstraint, TFL_StatefulOp]> { - +def TFL_UnidirectionalSequenceRNNOp : TFL_Op<"unidirectional_sequence_rnn", [ + TFL_OperandHasRank<4, 2>, + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + PredOpTrait<"input and constant value operands must have same element type", + TFL_TCopVTEtAreSameAt<1, 2>>, + TFL_StatefulOp]> { let summary = "Unidirectional sequence rnn operator"; let description = [{ @@ -3784,16 +4055,16 @@ def TFL_UnidirectionalSequenceRNNOp : }]; let arguments = ( - ins TFL_TensorOf<[F32, I8]>:$input, + ins TFL_FpTensor:$input, // Weights - TFL_TensorOf<[F32, I8]>:$input_to_input_weights, + TFL_TensorOf<[F32, QI8]>:$input_to_input_weights, // Recurrent weights - TFL_TensorOf<[F32, I8]>:$recurrent_to_input_weights, + TFL_TensorOf<[F32, QI8]>:$recurrent_to_input_weights, // Bias - TFL_TensorOf<[F32]>:$input_gate_bias, + TFL_FpTensor:$input_gate_bias, // Hidden state. TFL_StatefulTensor:$hidden_state, @@ -3803,7 +4074,7 @@ def TFL_UnidirectionalSequenceRNNOp : TFL_AFAttr:$fused_activation_function ); - let results = (outs TFL_TensorOf<[F32, I8]>:$output); + let results = (outs TFL_FpTensor:$output); let hasOptions = 1; @@ -3849,7 +4120,7 @@ def TFL_NumericVerifyOp : Op:$input, + TFL_TensorOf<[QI8, QUI8, QI16, F16, TFL_Quint8]>:$input, TFL_TensorOf<[F32]>:$ref, // Attributes @@ -3859,14 +4130,12 @@ def TFL_NumericVerifyOp : Op>; - // SVDF op. def TFL_SVDFOp : - TFL_Op<"svdf", - [SVDFResultConstraint, TFL_StatefulOp]> { + TFL_Op<"svdf", [ + PredOpTrait<"the input and result tensor elemental types must be same", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + TFL_StatefulOp]> { let summary = "Single value decomposition filter operator"; @@ -3878,13 +4147,13 @@ def TFL_SVDFOp : }]; let arguments = ( - ins TFL_TensorOf<[F32, I8]>:$input, + ins TFL_TensorOf<[F32, QI8]>:$input, // Feature Weights. 
- TFL_TensorOf<[F32, I8]>:$feature_weights, + TFL_TensorOf<[F32, QI8, QUI8]>:$feature_weights, // Time weights - TFL_TensorOf<[F32, I8]>:$time_weights, + TFL_TensorOf<[F32, QI8]>:$time_weights, // Bias TFL_TensorOfOrNone<[F32]>:$input_gate_bias, @@ -3893,11 +4162,11 @@ def TFL_SVDFOp : TFL_StatefulTensor:$activation_state, // Attributes - I32Attr:$rank, + Confined:$rank, TFL_AFAttr:$fused_activation_function ); - let results = (outs TFL_TensorOf<[F32, I8]>:$output); + let results = (outs TFL_TensorOf<[F32, QI8]>:$output); let hasOptions = 1; @@ -3909,7 +4178,10 @@ def TFL_SVDFOp : }]; } -def TFL_SegmentSumOp: TFL_Op<"segment_sum", [NoSideEffect]> { +def TFL_SegmentSumOp: TFL_Op<"segment_sum", [ + NoSideEffect, + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>]> { let summary = "SegmentSum operator"; let description = [{ @@ -3917,7 +4189,7 @@ def TFL_SegmentSumOp: TFL_Op<"segment_sum", [NoSideEffect]> { }]; let arguments = (ins - TFL_TensorOf<[F32, I32]>:$data, + TFL_TensorOf<[F32, I32]>:$input, TFL_I32Tensor:$segment_ids ); let results = (outs TFL_TensorOf<[F32, I32]>:$output); @@ -3936,8 +4208,8 @@ def TFL_YieldOp : Op { } def TFL_WhileOp : Op, - SingleBlockImplicitTerminator<"YieldOp">]> { + DeclareOpInterfaceMethods, + SingleBlockImplicitTerminator<"YieldOp">]> { let summary = [{While loop}]; let description = [{ @@ -3948,7 +4220,7 @@ def TFL_WhileOp : Op node_names; std::vector node_dtypes; std::vector> node_shapes; - std::vector node_mins; - std::vector node_maxs; + std::vector> node_mins; + std::vector> node_maxs; // Populate quantization specs. TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs( diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index c338b723a4a..ab80746f8b7 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -125,8 +125,8 @@ Status ConvertSavedModelToTFLiteFlatBuffer( std::vector node_names; std::vector node_dtypes; std::vector> node_shapes; - std::vector node_mins; - std::vector node_maxs; + std::vector> node_mins; + std::vector> node_maxs; // Populate quantization specs. 
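Both converter call sites above now collect node_mins/node_maxs as vectors of optionals, so an input flag without mean/std stats contributes an empty slot instead of a fabricated range (the matching has_mean_value()/has_std_value() branch appears in tf_tfl_flatbuffer_helpers.cc further down). A standalone sketch of the same pattern using std::optional and a made-up InputStats struct; the min/max formula assumes the toco convention real = (quantized - mean) / std on an 8-bit unsigned grid:

```
#include <cstdio>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical per-input stats; in the real flags these come from
// flag.mean_value() / flag.std_value() and may be absent.
struct InputStats {
  std::optional<double> mean;
  std::optional<double> stddev;
};

// Convert mean/std to a (min, max) range for an 8-bit unsigned grid,
// assuming real = (quantized - mean) / std.
std::pair<double, double> StatsToMinMax(double mean, double stddev) {
  return {(0.0 - mean) / stddev, (255.0 - mean) / stddev};
}

int main() {
  std::vector<InputStats> flags = {{127.5, 127.5}, {std::nullopt, std::nullopt}};
  std::vector<std::optional<double>> node_mins, node_maxs;
  for (const InputStats& f : flags) {
    if (f.mean && f.stddev) {
      auto [mn, mx] = StatsToMinMax(*f.mean, *f.stddev);
      node_mins.push_back(mn);
      node_maxs.push_back(mx);
    } else {
      // No stats recorded for this input: keep the slot, mark it empty.
      node_mins.push_back(std::nullopt);
      node_maxs.push_back(std::nullopt);
    }
  }
  std::printf("input 0 min=%g max=%g, input 1 has stats: %d\n",
              *node_mins[0], *node_maxs[0],
              static_cast<int>(node_mins[1].has_value()));  // -1 1 0
  return 0;
}
```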
TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs( @@ -146,6 +146,10 @@ Status ConvertSavedModelToTFLiteFlatBuffer( saved_model_exported_names.begin(), saved_model_exported_names.end()); absl::Span exported_names(exported_names_in_vector); + if (exported_names.size() != 1) { + return errors::Unimplemented("Only support a single exported name."); + } + TF_ASSIGN_OR_RETURN(auto module, ImportSavedModel(model_flags.saved_model_dir(), model_flags.saved_model_version(), tags, diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc index 6dd44e666fb..8f2c8bc362c 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc @@ -121,6 +121,8 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) { return DT_STRING; case toco::IODataType::BOOL: return DT_BOOL; + case toco::IODataType::COMPLEX64: + return DT_COMPLEX64; default: return DT_INVALID; } @@ -175,14 +177,13 @@ Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags) { return RegisterCustomBuiltinOps(extra_tf_opdefs); } -Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags, - const toco::TocoFlags& toco_flags, - mlir::TFL::QuantizationSpecs* quant_specs, - std::vector* node_names, - std::vector* node_dtypes, - std::vector>* node_shapes, - std::vector* node_mins, - std::vector* node_maxs) { +Status PopulateQuantizationSpecs( + const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags, + mlir::TFL::QuantizationSpecs* quant_specs, std::vector* node_names, + std::vector* node_dtypes, + std::vector>* node_shapes, + std::vector>* node_mins, + std::vector>* node_maxs) { quant_specs->inference_input_type = ConvertIODataTypeToDataType(toco_flags.inference_input_type()); tensorflow::DataType inference_type = @@ -209,11 +210,16 @@ Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags, flag.shape().dims().end())); // Currently, only UINT8 and INT8 require inputs stats if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) { - TF_ASSIGN_OR_RETURN( - auto min_max, InputStatsToMinMax(flag.mean_value(), flag.std_value(), - inference_type)); - node_mins->push_back(min_max.first); - node_maxs->push_back(min_max.second); + if (flag.has_mean_value() && flag.has_std_value()) { + TF_ASSIGN_OR_RETURN( + auto min_max, InputStatsToMinMax(flag.mean_value(), + flag.std_value(), inference_type)); + node_mins->push_back(min_max.first); + node_maxs->push_back(min_max.second); + } else { + node_mins->push_back(llvm::None); + node_maxs->push_back(llvm::None); + } } } @@ -252,7 +258,7 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) { std::string error_message; auto output = mlir::openOutputFile(filename, &error_message); if (!error_message.empty()) { - return errors::InvalidArgument("Failed to open file in %s.", filename); + return errors::InvalidArgument("Failed to open file in ", filename); } mlir::PassManager pm(module.getContext()); pm.addPass(mlir::createPrintOpGraphPass(output->os())); diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h index 3ea36e5eb1d..87e73912a46 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h @@ -34,14 +34,13 @@ Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags); 
// Populate quantization specs (or not) given user specified ranges for each // input arrays. -Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags, - const toco::TocoFlags& toco_flags, - mlir::TFL::QuantizationSpecs* quant_specs, - std::vector* node_names, - std::vector* node_dtypes, - std::vector>* node_shapes, - std::vector* node_mins, - std::vector* node_maxs); +Status PopulateQuantizationSpecs( + const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags, + mlir::TFL::QuantizationSpecs* quant_specs, std::vector* node_names, + std::vector* node_dtypes, + std::vector>* node_shapes, + std::vector>* node_mins, + std::vector>* node_maxs); // Convert imported MLIR file to TfLite flatbuffer. // This will also run relevant passes as well. diff --git a/tensorflow/compiler/mlir/lite/quantization/BUILD b/tensorflow/compiler/mlir/lite/quantization/BUILD index 23a65a88186..57417e95ec6 100644 --- a/tensorflow/compiler/mlir/lite/quantization/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/BUILD @@ -3,6 +3,10 @@ load( "//tensorflow/core/platform:build_config.bzl", "tf_proto_library", ) +load( + "//third_party/mlir:tblgen.bzl", + "gentbl", +) package( default_visibility = [ @@ -23,6 +27,7 @@ package_group( exports_files([ "quantization_traits.h", "quantization_config.h", + "quantization_utils.h", ]) filegroup( @@ -34,6 +39,25 @@ filegroup( ], ) +gentbl( + name = "quantization_interfaces_inc_gen", + tbl_outs = [ + ( + "-gen-op-interface-decls", + "quantization_interface.h.inc", + ), + ( + "-gen-op-interface-defs", + "quantization_interface.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "quantization.td", + td_srcs = [ + ":quantization_td_files", + ], +) + tf_proto_library( name = "quantization_info_proto", srcs = [ @@ -71,9 +95,11 @@ cc_library( name = "quantization_lib", srcs = [ "quantization_driver.cc", + "quantization_interface.cc.inc", "quantization_utils.cc", ], hdrs = [ + "quantization_interface.h.inc", "quantization_traits.h", "quantization_utils.h", ], diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD index b4fddceb580..2783297814b 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD @@ -49,14 +49,16 @@ cc_library( ], hdrs = [ "tfl_to_std.h", + "//tensorflow/compiler/mlir/lite/quantization:quantization_utils.h", ], deps = [ "//tensorflow/compiler/mlir/lite:tensorflow_lite", - "@com_google_absl//absl/strings", + "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "@llvm-project//llvm:support", "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc index 0ac3fa419bc..a2e3c065113 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc @@ -30,6 +30,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace mlir { namespace lite { @@ -38,6 +39,7 @@ namespace lite { TfLiteStatus QuantizeModel( const tflite::ModelT& input_model, const tflite::TensorType& input_type, const tflite::TensorType& output_type, + const tflite::TensorType& inference_type, const std::unordered_set& operator_names, bool disable_per_channel, bool fully_quantize, flatbuffers::FlatBufferBuilder* builder, @@ -73,7 +75,7 @@ TfLiteStatus QuantizeModel( // Apply quantization passes PassManager pm(module->getContext()); TFL::QuantizationSpecs quant_specs; - quant_specs.inference_type = tensorflow::DT_QINT8; + quant_specs.inference_type = tflite::TflTypeToTfType(inference_type); quant_specs.post_training_quantization = true; quant_specs.disable_per_channel = disable_per_channel; @@ -81,8 +83,10 @@ TfLiteStatus QuantizeModel( auto input_tf_type = tflite::TflTypeToTfType(input_type); if (input_tf_type == tensorflow::DT_FLOAT) { emit_adaptor = true; - } else if (input_tf_type == tensorflow::DT_UINT8) { - quant_specs.inference_type = tensorflow::DT_QUINT8; + } else if (input_tf_type == tensorflow::DT_UINT8 || + input_tf_type == tensorflow::DT_INT8 || + input_tf_type == tensorflow::DT_INT16) { + quant_specs.inference_type = input_tf_type; } pm.addPass(TFL::CreatePrepareQuantizePass(quant_specs)); diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h index 578aa6438de..d60df56b473 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h @@ -26,11 +26,13 @@ namespace mlir { namespace lite { // Quantize the `input_model` and write the result to a flatbuffer `builder`. -// The `input_type` and `output_type` can be float32/qint8/int8. +// The `input_type`, `output_type` and `inference_type` can be +// float32/qint8/int8/int16. // Return partially quantized model if `fully_quantize` is false. 
TfLiteStatus QuantizeModel( const tflite::ModelT& input_model, const tflite::TensorType& input_type, const tflite::TensorType& output_type, + const tflite::TensorType& inference_type, const std::unordered_set& operator_names, bool disable_per_channel, bool fully_quantize, flatbuffers::FlatBufferBuilder* builder, diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc index 77bd87a3c03..5bd1b71e631 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc @@ -46,7 +46,8 @@ TfLiteStatus QuantizeAnnotatedModel(llvm::StringRef buffer, tflite::StderrReporter error_reporter; return mlir::lite::QuantizeModel( - *model, tflite::TensorType_INT8, tflite::TensorType_INT8, {}, + *model, tflite::TensorType_INT8, tflite::TensorType_INT8, + tflite::TensorType_INT8, {}, /*disable_per_channel=*/false, /*fully_quantize=*/true, builder, &error_reporter); } diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.cc b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.cc index 8ea1709b15f..6b226fa68e7 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.cc @@ -17,6 +17,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" namespace mlir { namespace TFL { @@ -47,12 +48,18 @@ void ConvertMlirQuantOpsToTFLQuantOps(FuncOp func) { auto dcast = b.create(dq.getLoc(), dq.getResult().getType(), dq.arg()); dq.getResult().replaceAllUsesWith(dcast); + if (auto extra_attr = op->getAttr(mlir::quant::kVolatileOpAttrName)) { + dcast.setAttr(mlir::quant::kVolatileOpAttrName, extra_attr); + } dq.erase(); } else if (auto q = llvm::dyn_cast(op)) { auto out_type = q.getResult().getType(); auto qcast = b.create(q.getLoc(), out_type, q.arg(), TypeAttr::get(out_type)); q.getResult().replaceAllUsesWith(qcast); + if (auto extra_attr = op->getAttr(mlir::quant::kVolatileOpAttrName)) { + qcast.setAttr(mlir::quant::kVolatileOpAttrName, extra_attr); + } q.erase(); } }); diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization.td b/tensorflow/compiler/mlir/lite/quantization/quantization.td index 7bfcdb65686..c1e392bd3ad 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization.td +++ b/tensorflow/compiler/mlir/lite/quantization/quantization.td @@ -63,6 +63,22 @@ def QI32 : QuantizedType<"Uniform", [32], 1>; // https://www.tensorflow.org/lite/performance/quantization_spec //===----------------------------------------------------------------------===// +// TODO(b/157870442): replace all FixedResultScale trait +def FixedOutputRangeInterface : OpInterface< + "FixedOutputRangeInterface"> { + let description = [{ + Interface for defining the fixed output range. + }]; + + let methods = [ + InterfaceMethod< + [{Returns the fixed output range.}], + "UniformQuantizedType", "GetFixedOutputRange", + (ins "bool":$sign, "int":$bit_width) + >, + ]; +} + // Specify this trait if the op has a fixed output value range. 
class FixedResultScale : NativeOpTrait::Impl")>; diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_config.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_config.cc index 6b897bd5608..3edd9c36760 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_config.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_config.cc @@ -45,7 +45,7 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names, absl::string_view inference_type, QuantizationSpecs* quant_specs) { std::vector input_nodes = absl::StrSplit(node_names, ','); - std::vector node_mins; + std::vector> node_mins; if (!min_values.empty()) { std::vector node_mins_str = absl::StrSplit(min_values, ','); for (int i = 0; i < node_mins_str.size(); i++) { @@ -57,7 +57,7 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names, } } - std::vector node_maxs; + std::vector> node_maxs; if (!max_values.empty()) { std::vector node_maxs_str = absl::StrSplit(max_values, ','); for (int i = 0; i < node_maxs_str.size(); i++) { @@ -79,11 +79,11 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names, quant_specs); } -bool GetInputNodeQuantSpecs(const std::vector& node_names, - const std::vector& node_mins, - const std::vector& node_maxs, - tensorflow::DataType inference_type, - QuantizationSpecs* quant_specs) { +bool GetInputNodeQuantSpecs( + const std::vector& node_names, + const std::vector>& node_mins, + const std::vector>& node_maxs, + tensorflow::DataType inference_type, QuantizationSpecs* quant_specs) { quant_specs->inference_type = inference_type; // If min/max are not specified, just return; diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_config.h b/tensorflow/compiler/mlir/lite/quantization/quantization_config.h index cac1df9eee1..0e766ec52b6 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_config.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_config.h @@ -69,7 +69,8 @@ struct QuantizationSpecs { // arguments. They are only used when `weight_quantization` is set to false, // and the model is required to have quantization parameters, either from // quantization aware training or calibration, for the remaining tensors. - std::vector> input_ranges; + std::vector, llvm::Optional>> + input_ranges; // The default ranges can be used when a tensor doesn't have quantization // parameters and couldn't be quantized. Used only for latency tests. @@ -90,7 +91,7 @@ struct QuantizationSpecs { bool RunWeightQuantization() const { return weight_quantization; } // Whether this inference type represents a signed storage type. - bool IsSignedInferenceType() { + bool IsSignedInferenceType() const { switch (inference_type) { case tensorflow::DT_QUINT8: case tensorflow::DT_QUINT16: @@ -102,7 +103,7 @@ struct QuantizationSpecs { // Gets the width of this quantization type. Returns 0 if it isn't a // quantization type. - int64_t GetQuantizationTypeWidth() { + int64_t GetQuantizationTypeWidth() const { switch (inference_type) { case tensorflow::DT_QINT8: case tensorflow::DT_QUINT8: @@ -130,11 +131,11 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names, // Gets the quantization specification for input arrays. The array names are not // stored in the spec, and will be matched by position. The min/max will be // ignored if the inference_type isn't a quantized type. Returns true if failed. 
-bool GetInputNodeQuantSpecs(const std::vector& node_names, - const std::vector& node_mins, - const std::vector& node_maxs, - tensorflow::DataType inference_type, - QuantizationSpecs* quant_specs); +bool GetInputNodeQuantSpecs( + const std::vector& node_names, + const std::vector>& node_mins, + const std::vector>& node_maxs, + tensorflow::DataType inference_type, QuantizationSpecs* quant_specs); } // namespace TFL } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc index 2964a3e79f8..89443b1ec65 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc @@ -494,6 +494,13 @@ void QuantizationDriver::QuantizeValue(Value value, QuantParams params, auto quantize = builder_.create(loc, new_type, value); auto dequantize = builder_.create( loc, expressed_type, quantize.getResult()); + + // This attribute is set to distinguish the quantize ops being added by the + // quantization pass. These ops can be removed without losing original + // program accuracy. + // TODO(fengliuai): make the attribute being part of op definition. + quantize.setAttr(kVolatileOpAttrName, builder_.getUnitAttr()); + // `original_result` has a use to `quantize`, so this will replace that use // by the result of `dequantize`. Remember to reset that use afterwards value.replaceAllUsesWith(dequantize); diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h b/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h index b59164b72e6..693f692c61a 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h @@ -21,13 +21,18 @@ limitations under the License. #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project -namespace mlir { -namespace OpTrait { -namespace quant { - using QuantizedType = mlir::quant::QuantizedType; using UniformQuantizedType = mlir::quant::UniformQuantizedType; +namespace mlir { + +// This includes the interface class definition. It couldn't be in a namespace +// because the table gen doesn't emit the namespace when it is used. +#include "tensorflow/compiler/mlir/lite/quantization/quantization_interface.h.inc" + +namespace OpTrait { +namespace quant { + // The base class that all the quantization related OpTrait implements. template class TraitType> struct QuantizationSpecTraitBase : public TraitBase { diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc index 3d50f280d0f..32f68aaae5f 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc @@ -22,6 +22,7 @@ limitations under the License. 
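The quantization_driver.cc change above tags every quantize op the driver synthesizes with the unit attribute named by kVolatileOpAttrName (declared in quantization_utils.h later in this patch), and ConvertMlirQuantOpsToTFLQuantOps forwards that tag across the dialect conversion. The point is that downstream cleanups can tell converter-inserted casts from user-visible ones. A minimal sketch of such a cleanup, under the assumption that dropping a volatile qcast/dcast round trip is numerically a no-op; this is illustrative, not the pass the patch actually ships:

#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/IR/Function.h"             // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"

// Bypass quantize/dequantize round trips whose quantize op carries the
// "volatile" tag, i.e. was inserted by the quantization driver itself.
void RemoveVolatileRoundTrips(mlir::FuncOp func) {
  llvm::SmallVector<mlir::quant::DequantizeCastOp, 4> candidates;
  func.walk([&](mlir::quant::DequantizeCastOp dq) {
    auto q = llvm::dyn_cast_or_null<mlir::quant::QuantizeCastOp>(
        dq.arg().getDefiningOp());
    if (q && q.getOperation()->getAttr(mlir::quant::kVolatileOpAttrName))
      candidates.push_back(dq);
  });
  for (auto dq : candidates) {
    auto q = llvm::cast<mlir::quant::QuantizeCastOp>(dq.arg().getDefiningOp());
    // Users of the dequantize read the original float value directly, then
    // the now-dead volatile quantize is dropped.
    dq.getResult().replaceAllUsesWith(q.arg());
    dq.erase();
    if (q.use_empty()) q.erase();
  }
}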
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" #include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project @@ -436,6 +437,16 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func, llvm::SmallVector all_stats_ops; llvm::DenseSet redundant_stats_ops; + // Step 0: remove the quant::StatisticsOp which are used by the tfl.quantize + // op in case it overrides the information from training FakeQuant ops. + func.walk([&](quant::QuantizeCastOp q) { + auto input_op = q.arg().getDefiningOp(); + if (auto stats = llvm::dyn_cast_or_null(input_op)) { + q.setOperand(stats.arg()); + if (stats.use_empty()) stats.erase(); + } + }); + // Step 1: forward pass: propagate any value scales which are not produces // by `SameOperandsAndResultsScale`. Additionally, remove the value scales // which are produced by the `restricted_output_params`. diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h index 27ccc7d2b22..f17e44cd756 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h @@ -22,6 +22,8 @@ limitations under the License. #include #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project @@ -35,11 +37,17 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h" namespace mlir { namespace quant { +// A unit attribute can be attached to the quantize/dequantize ops which are +// added by the quantization passes. These ops can be removed erased without +// losing accuracy. +constexpr char kVolatileOpAttrName[] = "volatile"; + using QuantParams = quant::QuantizedType; using SignedInteger = std::pair; // bitwidth and sign using QuantParamsForResults = llvm::SmallVector; @@ -363,6 +371,55 @@ struct ConvertUnsignedToSigned : public OpRewritePattern { } }; +// Fold Extra Requantize ops if the preceding ops has free scale requirement. +template +struct FoldTrivalRequantizeOp : public OpRewritePattern { + explicit FoldTrivalRequantizeOp(MLIRContext* context) + : OpRewritePattern(context, 1) {} + + LogicalResult matchAndRewrite(RQ op, + PatternRewriter& rewriter) const override { + Value pre_quantized = op.input(); + auto pre_quantized_type = + quant::QuantizedType::getQuantizedElementType(pre_quantized.getType()); + if (!pre_quantized_type) return failure(); + + Operation* def = pre_quantized.getDefiningOp(); + if (!def) return failure(); + if (llvm::isa(def) || + def->hasTrait() || + def->hasTrait()) { + return failure(); + } + + op.emitWarning("Remove trivial `rescale` op. 
Please fix the source graph."); + + llvm::SmallVector new_output_types; + for (auto result : def->getResults()) { + result.getUsers().begin()->dump(); + op.dump(); + if (result.hasOneUse() && *result.getUsers().begin() == op) { + new_output_types.push_back(op.qtype()); + } else { + new_output_types.push_back(result.getType()); + } + } + + // Remove this rescale op. + rewriter.replaceOp(op, {pre_quantized}); + + // Replace the output scale of the preceding op. + rewriter.setInsertionPointAfter(def); + OperationState new_state(def->getLoc(), def->getName().getStringRef(), + def->getOperands(), new_output_types, + def->getAttrs()); + Operation* new_op = rewriter.createOperation(new_state); + + rewriter.replaceOp(def, new_op->getResults()); + return success(); + } +}; + // Given a quantized type `input`, magnifying its scales by the factor stored in // `factor`. If `input` isn't a quantized type or the `factor` doesn't match the // dimension size of `input` or isn't floating-point, nullptr will be returned. diff --git a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir index 1f067aae685..5c69130c939 100644 --- a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir @@ -11,9 +11,9 @@ func @reshape_removeAdjacent(tensor<4x4x4xf32>) -> tensor<64xf32> { return %1 : tensor<64xf32> // CHECK-LABEL: func @reshape_removeAdjacent -// CHECK: %cst = constant dense<64> : tensor<1xi32> -// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32> -// CHECK: return +// CHECK: %[[CST:.*]] = constant dense<64> : tensor<1xi32> +// CHECK: %[[RESHAPE:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32> +// CHECK: return %[[RESHAPE]] } // Checks that tfl.reshape should be removed if its output has more than one @@ -29,11 +29,11 @@ func @reshape_removeAdjacentWithMultipleUse(tensor<4x4x4xf32>) -> tensor<64xf32> return %3 : tensor<64xf32> // CHECK-LABEL: func @reshape_removeAdjacentWithMultipleUse -// CHECK: %cst = constant dense<64> : tensor<1xi32> -// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32> -// CHECK: %1 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32> -// CHECK: %2 = addf %0, %1 -// CHECK: return %2 +// CHECK: %[[CST:.*]] = constant dense<64> : tensor<1xi32> +// CHECK: %[[RESHAPE_1:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32> +// CHECK: %[[RESHAPE_2:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32> +// CHECK: %[[RESULT:.*]] = addf %[[RESHAPE_1]], %[[RESHAPE_2]] +// CHECK: return %[[RESULT]] } // Checks that tfl.reshape should be kept if its output has more than one @@ -47,11 +47,11 @@ func @reshape_keepAdjacentWithMultipleUse(tensor<4x4x4xf32>) -> (tensor<16x4xf32 return %0, %1 : tensor<16x4xf32>, tensor<64xf32> // CHECK-LABEL: func @reshape_keepAdjacentWithMultipleUse -// CHECK: %cst = constant dense<[16, 4]> : tensor<2xi32> -// CHECK: %cst_0 = constant dense<64> : tensor<1xi32> -// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<2xi32>) -> tensor<16x4xf32> -// CHECK: %1 = "tfl.reshape"(%arg0, %cst_0) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32> -// CHECK: return %0, %1 +// CHECK: %[[CST:.*]] = constant dense<[16, 4]> : tensor<2xi32> +// CHECK: %[[CST_0:.*]] = constant dense<64> : tensor<1xi32> +// CHECK: 
%[[RESHAPE_1:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<2xi32>) -> tensor<16x4xf32> +// CHECK: %[[RESHAPE_2:.*]] = "tfl.reshape"(%arg0, %[[CST_0]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32> +// CHECK: return %[[RESHAPE_1]], %[[RESHAPE_2]] } // Checks that tfl.reshape should be removed if its output type is the same diff --git a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir index 4b8993e2b26..a8463d51c7e 100644 --- a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir +++ b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir @@ -8,13 +8,13 @@ func @add_float() -> (tensor, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, %2 = constant dense< 3.5> : tensor<4xf32> %3 = constant dense<-0.5> : tensor<4xf32> - // CHECK: %cst = constant dense<3.500000e+00> : tensor<4xf32> - // CHECK: %cst_0 = constant dense<-5.000000e-01> : tensor<4xf32> - // CHECK: %cst_1 = constant dense<6.000000e+00> : tensor - // CHECK: %cst_2 = constant dense<4.000000e+00> : tensor<4xf32> - // CHECK: %cst_3 = constant dense<5.000000e+00> : tensor<4xf32> - // CHECK: %cst_4 = constant dense<3.000000e+00> : tensor<4xf32> - // CHECK: %0 = tfl.add %cst, %cst_0 {fused_activation_function = "SIGN_BIT"} : tensor<4xf32> + // CHECK: %[[CST:.*]] = constant dense<3.500000e+00> : tensor<4xf32> + // CHECK: %[[CST_0:.*]] = constant dense<-5.000000e-01> : tensor<4xf32> + // CHECK: %[[CST_1:.*]] = constant dense<6.000000e+00> : tensor + // CHECK: %[[CST_2:.*]] = constant dense<4.000000e+00> : tensor<4xf32> + // CHECK: %[[CST_3:.*]] = constant dense<5.000000e+00> : tensor<4xf32> + // CHECK: %[[CST_4:.*]] = constant dense<3.000000e+00> : tensor<4xf32> + // CHECK: %0 = tfl.add %[[CST]], %[[CST_0]] {fused_activation_function = "SIGN_BIT"} : tensor<4xf32> %5 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32> %6 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32> @@ -33,10 +33,10 @@ func @add_int() -> (tensor, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) { %2 = constant dense< 4> : tensor<4xi32> %3 = constant dense<-2> : tensor<4xi32> - // CHECK: %cst = constant dense<9> : tensor - // CHECK: %cst_0 = constant dense<6> : tensor<4xi32> - // CHECK: %cst_1 = constant dense<5> : tensor<4xi32> - // CHECK: %cst_2 = constant dense<2> : tensor<4xi32> + // CHECK: %[[CST:.*]] = constant dense<9> : tensor + // CHECK: %[[CST_0:.*]] = constant dense<6> : tensor<4xi32> + // CHECK: %[[CST_1:.*]] = constant dense<5> : tensor<4xi32> + // CHECK: %[[CST_2:.*]] = constant dense<2> : tensor<4xi32> %5 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor< i32>, tensor< i32>) -> tensor< i32> %6 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor< i32>, tensor<4xi32>) -> tensor<4xi32> @@ -54,10 +54,10 @@ func @sub_float() -> (tensor, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) %2 = constant dense< 3.5> : tensor<4xf32> %3 = constant dense<-0.5> : tensor<4xf32> - // CHECK: %cst = constant dense<3.000000e+00> : tensor - // CHECK: %cst_0 = constant dense<5.000000e+00> : tensor<4xf32> - // CHECK: %cst_1 = constant dense<2.000000e+00> : tensor<4xf32> - // CHECK: %cst_2 = constant dense<4.000000e+00> : tensor<4xf32> + // CHECK: %[[CST:.*]] = constant dense<3.000000e+00> : tensor + // CHECK: %[[CST_0:.*]] = constant dense<5.000000e+00> : tensor<4xf32> + // CHECK: %[[CST_1:.*]] = constant dense<2.000000e+00> : tensor<4xf32> + // CHECK: %[[CST_2:.*]] = 
constant dense<4.000000e+00> : tensor<4xf32> %5 = "tfl.sub"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32> %6 = "tfl.sub"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32> @@ -75,10 +75,10 @@ func @sub_int() -> (tensor, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) { %2 = constant dense< 4> : tensor<4xi32> %3 = constant dense<-2> : tensor<4xi32> - // CHECK: %cst = constant dense<7> : tensor - // CHECK: %cst_0 = constant dense<10> : tensor<4xi32> - // CHECK: %cst_1 = constant dense<3> : tensor<4xi32> - // CHECK: %cst_2 = constant dense<6> : tensor<4xi32> + // CHECK: %[[CST:.*]] = constant dense<7> : tensor + // CHECK: %[[CST_0:.*]] = constant dense<10> : tensor<4xi32> + // CHECK: %[[CST_1:.*]] = constant dense<3> : tensor<4xi32> + // CHECK: %[[CST_2:.*]] = constant dense<6> : tensor<4xi32> %5 = "tfl.sub"(%0, %1) {fused_activation_function = "NONE"} : (tensor< i32>, tensor< i32>) -> tensor< i32> %6 = "tfl.sub"(%0, %3) {fused_activation_function = "NONE"} : (tensor< i32>, tensor<4xi32>) -> tensor<4xi32> @@ -96,10 +96,10 @@ func @mul_float() -> (tensor, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) %2 = constant dense< 3.5> : tensor<4xf32> %3 = constant dense<-0.5> : tensor<4xf32> - // CHECK: %cst = constant dense<6.750000e+00> : tensor - // CHECK: %cst_0 = constant dense<-2.250000e+00> : tensor<4xf32> - // CHECK: %cst_1 = constant dense<5.250000e+00> : tensor<4xf32> - // CHECK: %cst_2 = constant dense<-1.750000e+00> : tensor<4xf32> + // CHECK: %[[CST:.*]] = constant dense<6.750000e+00> : tensor + // CHECK: %[[CST_0:.*]] = constant dense<-2.250000e+00> : tensor<4xf32> + // CHECK: %[[CST_1:.*]] = constant dense<5.250000e+00> : tensor<4xf32> + // CHECK: %[[CST_2:.*]] = constant dense<-1.750000e+00> : tensor<4xf32> %5 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32> %6 = "tfl.mul"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32> @@ -170,8 +170,8 @@ func @add_dense_splat_int() -> tensor<4xi32> { return %2 : tensor<4xi32> -// CHECK: %cst = constant dense<[-5, 4, 47, 105]> : tensor<4xi32> -// CHECK: return %cst +// CHECK: %[[CST:.*]] = constant dense<[-5, 4, 47, 105]> : tensor<4xi32> +// CHECK: return %[[CST]] } // CHECK-LABEL: @add_splat_dense_int @@ -183,8 +183,8 @@ func @add_splat_dense_int() -> tensor<4xi32> { return %2 : tensor<4xi32> -// CHECK: %cst = constant dense<[-5, 4, 47, 105]> : tensor<4xi32> -// CHECK: return %cst +// CHECK: %[[CST:.*]] = constant dense<[-5, 4, 47, 105]> : tensor<4xi32> +// CHECK: return %[[CST]] } // CHECK-LABEL: @add_dense_dense_int_same_shape @@ -196,8 +196,8 @@ func @add_dense_dense_int_same_shape() -> tensor<4xi32> { return %2 : tensor<4xi32> -// CHECK: %cst = constant dense<[5, 22, -2, 98]> : tensor<4xi32> -// CHECK: return %cst +// CHECK: %[[CST:.*]] = constant dense<[5, 22, -2, 98]> : tensor<4xi32> +// CHECK: return %[[CST]] } // CHECK-LABEL: @add_dense_dense_int_trailing_dim @@ -212,10 +212,10 @@ func @add_dense_dense_int_trailing_dim() -> (tensor<2x2xi32>, tensor<2x2x2xi32>, return %0, %1, %2 : tensor<2x2xi32>, tensor<2x2x2xi32>, tensor<2x2x2xi32> -// CHECK: %cst = constant dense<{{\[\[}}11, 22], [13, 24]]> : tensor<2x2xi32> -// CHECK: %cst_0 = constant dense<{{\[\[\[}}2, 3], [5, 6]], {{\[\[}}4, 5], [7, 8]]]> : tensor<2x2x2xi32> -// CHECK: %cst_1 = constant dense<{{\[\[\[}}11, 21], [12, 22]], {{\[\[}}13, 23], [14, 24]]]> : tensor<2x2x2xi32> -// CHECK: return %cst, %cst_0, 
%cst_1 +// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}11, 22], [13, 24]]> : tensor<2x2xi32> +// CHECK: %[[CST_0:.*]] = constant dense<{{\[\[\[}}2, 3], [5, 6]], {{\[\[}}4, 5], [7, 8]]]> : tensor<2x2x2xi32> +// CHECK: %[[CST_1:.*]] = constant dense<{{\[\[\[}}11, 21], [12, 22]], {{\[\[}}13, 23], [14, 24]]]> : tensor<2x2x2xi32> +// CHECK: return %[[CST]], %[[CST_0]], %[[CST_1]] } // CHECK-LABEL: @add_dense_dense_int_mixing_1_n @@ -226,8 +226,8 @@ func @add_dense_dense_int_mixing_1_n() -> tensor<2x2xi32> { %0 = "tfl.add"(%cst_0, %cst_1) {fused_activation_function = "NONE"} : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> -// CHECK: %cst = constant dense<{{\[\[}}4, 5], [5, 6]]> : tensor<2x2xi32> -// CHECK: return %cst +// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}4, 5], [5, 6]]> : tensor<2x2xi32> +// CHECK: return %[[CST]] } // CHECK-LABEL: @add_dense_splat_float @@ -239,8 +239,8 @@ func @add_dense_splat_float() -> tensor<4xf32> { return %2 : tensor<4xf32> -// CHECK: %cst = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32> -// CHECK: return %cst +// CHECK: %[[CST:.*]] = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32> +// CHECK: return %[[CST]] } // CHECK-LABEL: @add_splat_dense_float @@ -252,8 +252,8 @@ func @add_splat_dense_float() -> tensor<4xf32> { return %2 : tensor<4xf32> -// CHECK: %cst = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32> -// CHECK: return %cst +// CHECK: %[[CST:.*]] = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32> +// CHECK: return %[[CST]] } // CHECK-LABEL: @add_dense_dense_float_same_shape @@ -265,8 +265,8 @@ func @add_dense_dense_float_same_shape() -> (tensor<4xf32>) { return %2 : tensor<4xf32> -// CHECK: %cst = constant dense<[-8.89999961, 1.000000e+00, 3.800000e+01, 9.800000e+01]> : tensor<4xf32> -// CHECK: return %cst +// CHECK: %[[CST:.*]] = constant dense<[-8.89999961, 1.000000e+00, 3.800000e+01, 9.800000e+01]> : tensor<4xf32> +// CHECK: return %[[CST]] } // CHECK-LABEL: @add_dense_dense_float_trailing_dim @@ -281,10 +281,10 @@ func @add_dense_dense_float_trailing_dim() -> (tensor<2x2xf32>, tensor<2x2x2xf32 return %0, %1, %2 : tensor<2x2xf32>, tensor<2x2x2xf32>, tensor<2x2x2xf32> -// CHECK: %cst = constant dense<{{\[\[}}-4.500000e+00, -2.500000e+00], [8.500000e+00, -8.500000e+00]]> : tensor<2x2xf32> -// CHECK: %cst_0 = constant dense<{{\[\[\[}}-4.500000e+00, 2.500000e+00], [9.500000e+00, -2.500000e+00]], {{\[\[}}-2.500000e+00, 4.500000e+00], [1.150000e+01, -5.000000e-01]]]> : tensor<2x2x2xf32> -// CHECK: %cst_1 = constant dense<{{\[\[\[}}2.000000e+00, -3.000000e+00], [3.000000e+00, -2.000000e+00]], {{\[\[}}4.000000e+00, -1.000000e+00], [5.000000e+00, 0.000000e+00]]]> : tensor<2x2x2xf32> -// CHECK: return %cst, %cst_0, %cst_1 +// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}-4.500000e+00, -2.500000e+00], [8.500000e+00, -8.500000e+00]]> : tensor<2x2xf32> +// CHECK: %[[CST_0:.*]] = constant dense<{{\[\[\[}}-4.500000e+00, 2.500000e+00], [9.500000e+00, -2.500000e+00]], {{\[\[}}-2.500000e+00, 4.500000e+00], [1.150000e+01, -5.000000e-01]]]> : tensor<2x2x2xf32> +// CHECK: %[[CST_1:.*]] = constant dense<{{\[\[\[}}2.000000e+00, -3.000000e+00], [3.000000e+00, -2.000000e+00]], {{\[\[}}4.000000e+00, -1.000000e+00], [5.000000e+00, 0.000000e+00]]]> : tensor<2x2x2xf32> +// CHECK: return %[[CST]], %[[CST_0]], %[[CST_1]] } // CHECK-LABEL: @add_dense_dense_float_mixfng_1_n @@ -296,24 
+296,24 @@ func @add_dense_dense_float_mixfng_1_n() -> tensor<2x2xf32> { return %0 : tensor<2x2xf32> -// CHECK: %cst = constant dense<{{\[\[}}-1.500000e+00, -5.500000e+00], [5.500000e+00, 1.500000e+00]]> : tensor<2x2xf32> -// CHECK: return %cst +// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}-1.500000e+00, -5.500000e+00], [5.500000e+00, 1.500000e+00]]> : tensor<2x2xf32> +// CHECK: return %[[CST]] } // CHECK-LABEL: @rank func @rank() -> tensor<1xi32> { %cst = constant dense<[[1], [2]]> : tensor<2x1xi32> - // CHECK: [[cst:%.*]] = constant dense<2> : tensor<1xi32> - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = constant dense<2> : tensor<1xi32> + // CHECK: return %[[CST]] %0 = "tfl.rank"(%cst) : (tensor<2x1xi32>) -> tensor<1xi32> return %0 : tensor<1xi32> } // CHECK-LABEL: @rank_input_known_rank func @rank_input_known_rank(%arg0 : tensor<2x1xi32>) -> tensor<1xi32> { - // CHECK: [[cst:%.*]] = constant dense<2> : tensor<1xi32> - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = constant dense<2> : tensor<1xi32> + // CHECK: return %[[CST]] %0 = "tfl.rank"(%arg0) : (tensor<2x1xi32>) -> tensor<1xi32> return %0 : tensor<1xi32> } @@ -323,8 +323,8 @@ func @reshape() -> tensor<4xi32> { %input = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> %shape = constant dense<[4]> : tensor<1xi32> - // CHECK: [[cst:%.*]] = constant dense<[1, 2, 3, 4]> : tensor<4xi32> - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = constant dense<[1, 2, 3, 4]> : tensor<4xi32> + // CHECK: return %[[CST]] %0 = "tfl.reshape"(%input, %shape) : (tensor<2x2xi32>, tensor<1xi32>) -> tensor<4xi32> return %0 : tensor<4xi32> } @@ -334,8 +334,8 @@ func @reshape_dynamic_output() -> tensor { %input = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> %shape = constant dense<[4]> : tensor<1xi32> - // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> tensor - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> tensor + // CHECK: return %[[CST]] %0 = "tfl.reshape"(%input, %shape) : (tensor<2x2xi32>, tensor<1xi32>) -> tensor return %0 : tensor } @@ -343,8 +343,8 @@ func @reshape_dynamic_output() -> tensor { // CHECK-LABEL: @pseudo_const func @pseudo_const() -> tensor { - // CHECK: [[cst:%.*]] = constant dense<1> : tensor - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = constant dense<1> : tensor + // CHECK: return %[[CST]] %0 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor return %0 : tensor } @@ -356,8 +356,8 @@ func @range_int() -> tensor { %cst_1 = constant dense<4> : tensor %cst_2 = constant dense<1> : tensor - // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor + // CHECK: return %[[CST]] %0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor, tensor, tensor) -> tensor return %0 : tensor } @@ -368,8 +368,8 @@ func @range_float() -> tensor { %cst_1 = constant dense<4.0> : tensor %cst_2 = constant dense<1.0> : tensor - // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>} : () -> tensor - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>} : () -> tensor + // CHECK: return %[[CST]] %0 = "tfl.range"(%cst, %cst_1, %cst_2) : 
(tensor, tensor, tensor) -> tensor return %0 : tensor } @@ -381,8 +381,8 @@ func @range_float_neg_delta() -> tensor { %cst_1 = constant dense<-4.0> : tensor %cst_2 = constant dense<-1.0> : tensor - // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, -1.000000e+00, -2.000000e+00, -3.000000e+00]> : tensor<4xf32>} : () -> tensor - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, -1.000000e+00, -2.000000e+00, -3.000000e+00]> : tensor<4xf32>} : () -> tensor + // CHECK: return %[[CST]] %0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor, tensor, tensor) -> tensor return %0 : tensor } @@ -393,8 +393,8 @@ func @range_float_nonzero_base() -> tensor { %cst_1 = constant dense<7.0> : tensor %cst_2 = constant dense<1.5> : tensor - // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[2.000000e+00, 3.500000e+00, 5.000000e+00, 6.500000e+00]> : tensor<4xf32>} : () -> tensor - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[2.000000e+00, 3.500000e+00, 5.000000e+00, 6.500000e+00]> : tensor<4xf32>} : () -> tensor + // CHECK: return %[[CST]] %0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor, tensor, tensor) -> tensor return %0 : tensor } @@ -414,8 +414,8 @@ func @transpose_1d() -> tensor<3xi32> { %cst = constant dense<[1, 2, 3]> : tensor<3xi32> %cst_perm = constant dense<0> : tensor<1xi32> - // CHECK: [[cst:%.*]] = constant dense<{{\[}}1, 2, 3]> : tensor<3xi32> - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = constant dense<{{\[}}1, 2, 3]> : tensor<3xi32> + // CHECK: return %[[CST]] %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<3xi32>, tensor<1xi32>) -> tensor<3xi32> return %0 : tensor<3xi32> } @@ -425,8 +425,8 @@ func @transpose_dynamic() -> tensor { %cst = constant dense<[1, 2, 3]> : tensor<3xi32> %cst_perm = constant dense<0> : tensor<1xi32> - // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<{{\[}}1, 2, 3]> : tensor<3xi32>} : () -> tensor - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<{{\[}}1, 2, 3]> : tensor<3xi32>} : () -> tensor + // CHECK: return %[[CST]] %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<3xi32>, tensor<1xi32>) -> tensor return %0 : tensor } @@ -436,8 +436,8 @@ func @transpose_2d() -> tensor<2x2xi32> { %cst = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32> %cst_perm = constant dense<[1, 0]> : tensor<2xi32> - // CHECK: [[cst:%.*]] = constant dense<{{\[\[}}0, 2], {{\[}}1, 3]]> : tensor<2x2xi32> - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = constant dense<{{\[\[}}0, 2], {{\[}}1, 3]]> : tensor<2x2xi32> + // CHECK: return %[[CST]] %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> } @@ -447,8 +447,8 @@ func @transpose_2d_identity() -> tensor<2x2xi32> { %cst = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32> %cst_perm = constant dense<[0, 1]> : tensor<2xi32> - // CHECK: [[cst:%.*]] = constant dense<{{\[\[}}0, 1], {{\[}}2, 3]]> : tensor<2x2xi32> - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = constant dense<{{\[\[}}0, 1], {{\[}}2, 3]]> : tensor<2x2xi32> + // CHECK: return %[[CST]] %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> } @@ -460,8 +460,8 @@ func @transpose_3d() -> tensor<4x2x3xi32> { %cst = constant dense<[[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]]> : tensor<2x3x4xi32> %cst_perm = constant dense<[2, 
0, 1]> : tensor<3xi32> - // CHECK: [[cst:%.*]] = constant dense<{{\[\[\[}}0, 4, 8], {{\[}}12, 16, 20]], {{\[\[}}1, 5, 9], {{\[}}13, 17, 21]], {{\[\[}}2, 6, 10], {{\[}}14, 18, 22]], {{\[\[}}3, 7, 11], {{\[}}15, 19, 23]]]> : tensor<4x2x3xi32> - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 4, 8], {{\[}}12, 16, 20]], {{\[\[}}1, 5, 9], {{\[}}13, 17, 21]], {{\[\[}}2, 6, 10], {{\[}}14, 18, 22]], {{\[\[}}3, 7, 11], {{\[}}15, 19, 23]]]> : tensor<4x2x3xi32> + // CHECK: return %[[CST]] %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x3x4xi32>, tensor<3xi32>) -> tensor<4x2x3xi32> return %0 : tensor<4x2x3xi32> } @@ -473,8 +473,8 @@ func @ConstantFoldBinaryOpDynamicOutput() -> tensor { %87 = "tfl.sub"(%cst_0, %cst) {fused_activation_function = "NONE"} : (tensor, tensor) -> tensor return %87 : tensor - // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[-5, 0]> : tensor<2xi32>} : () -> tensor - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[-5, 0]> : tensor<2xi32>} : () -> tensor + // CHECK: return %[[CST]] } // CHECK-LABEL: @add_dense_dense_int_same_shape_dynamic @@ -486,8 +486,8 @@ func @add_dense_dense_int_same_shape_dynamic() -> tensor { return %2 : tensor - // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[5, 22, -2, 98]> : tensor<4xi32>} : () -> tensor - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[5, 22, -2, 98]> : tensor<4xi32>} : () -> tensor + // CHECK: return %[[CST]] } // CHECK-LABEL: @concat_2_tensors_1_empty @@ -497,8 +497,8 @@ func @concat_2_tensors_1_empty() -> tensor<2xi32> { %3 = "tfl.concatenation"(%1, %2) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<0xi32>) -> tensor<2xi32> return %3 : tensor<2xi32> - // CHECK: [[cst:%.*]] = constant dense<1> : tensor<2xi32> - // CHECK: return [[cst]] : tensor<2xi32> + // CHECK: %[[CST:.*]] = constant dense<1> : tensor<2xi32> + // CHECK: return %[[CST]] : tensor<2xi32> } // CHECK-LABEL: @concat_3_tensors_1_empty @@ -509,7 +509,7 @@ func @concat_3_tensors_1_empty() -> tensor { %3 = "tfl.concatenation"(%0, %1, %2) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<2xi32>, tensor<0xi32>) -> tensor return %3 : tensor - // CHECK: %0 = "tfl.concatenation"(%cst, %cst) {axis = 0 : i32, fused_activation_function = "NONE"} + // CHECK: %0 = "tfl.concatenation"(%[[CST]], %[[CST]]) {axis = 0 : i32, fused_activation_function = "NONE"} // CHECK: return %0 : tensor } @@ -520,10 +520,10 @@ func @concatConstantTensorsFirstDim() -> tensor<2x2x3xi32> { %0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<2x2x3xi32> return %0 : tensor<2x2x3xi32> - // CHECK: [[cst:%.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0]], {{\[}}{{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<2x2x3xi32> + // CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0]], {{\[}}{{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<2x2x3xi32> // CHECK-NOT: constant-dense // CHECK-NOT: "tfl.concatenation" - // CHECK: return [[cst]] + // CHECK: return %[[CST]] } // CHECK-LABEL: @concatConstantTensorsMiddleDim @@ -533,10 +533,10 @@ func @concatConstantTensorsMiddleDim() -> tensor<1x4x3xi32> { %0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<1x4x3xi32> return %0 : tensor<1x4x3xi32> - // CHECK: [[cst:%.*]] = constant 
dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0], {{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<1x4x3xi32> + // CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0], {{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<1x4x3xi32> // CHECK-NOT: constant-dense // CHECK-NOT: "tfl.concatenation" - // CHECK: return [[cst]] + // CHECK: return %[[CST]] } // CHECK-LABEL: @concatConstantTensorsLastDim @@ -546,10 +546,10 @@ func @concatConstantTensorsLastDim() -> tensor<1x2x6xi32> { %0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 2 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<1x2x6xi32> return %0 : tensor<1x2x6xi32> - // CHECK: [[cst:%.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0, 1, 1, 1], {{\[}}0, 0, 0, 1, 1, 1]]]> : tensor<1x2x6xi32> + // CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0, 1, 1, 1], {{\[}}0, 0, 0, 1, 1, 1]]]> : tensor<1x2x6xi32> // CHECK-NOT: constant-dense // CHECK-NOT: "tfl.concatenation" - // CHECK: return [[cst]] + // CHECK: return %[[CST]] } // CHECK-LABEL: @div_dense_dense_float_mixfng_1_n @@ -561,8 +561,8 @@ func @div_dense_dense_float_mixfng_1_n() -> tensor<2x2xf32> { return %0 : tensor<2x2xf32> -// CHECK: %cst = constant dense<{{\[\[}}-5.000000e-01, 0.833333313], [3.750000e-01, -6.250000e-01]]> : tensor<2x2xf32> -// CHECK: return %cst +// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}-5.000000e-01, 0.833333313], [3.750000e-01, -6.250000e-01]]> : tensor<2x2xf32> +// CHECK: return %[[CST]] } // CHECK-LABEL: @div_dense_different_rank @@ -574,6 +574,6 @@ func @div_dense_different_rank() -> tensor<1x2x2xf32> { return %0 : tensor<1x2x2xf32> -// CHECK: %cst = constant dense<[{{\[}}{{\[}}5.000000e-01, 0.333333343], [1.000000e+00, 0.666666686]]]> : tensor<1x2x2xf32> -// CHECK: return %cst +// CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}5.000000e-01, 0.333333343], [1.000000e+00, 0.666666686]]]> : tensor<1x2x2xf32> +// CHECK: return %[[CST]] } diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/back2back_fake_quant.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/back2back_fake_quant.pbtxt new file mode 100644 index 00000000000..31e2157d360 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/end2end/back2back_fake_quant.pbtxt @@ -0,0 +1,1186 @@ +# RUN: tf_tfl_translate --mlir-elide-elementsattrs-if-larger=10 --emit-builtin-tflite-ops \ +# RUN: --emit-select-tf-ops --tf-inference-type=DT_QUINT8 --tf-input-min-values=0.0 \ +# RUN: --tf-input-max-values=6.283185307179586 --tf-input-arrays=quant_dense_input --tf-input-shapes=1,1 \ +# RUN: --tf-output-arrays=Identity -o %t.tflite %s 2>&1 | FileCheck %s +# RUN: flatbuffer_to_string %t.tflite | FileCheck --check-prefix=RESULT1 %s + +node { + name: "quant_dense_input" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } +} +node { + name: "sequential/quant_dense/MatMul/ReadVariableOp/resource" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + dim { + size: 16 + } + } + tensor_content: "n\267\313\276@W\337>\227o9>PF\237\275|%}\276\333n\005>\371\005\031?\230\355\235\275\344\211_>\034\222\264=\254\003\345=Q\027*>\225\304\373>Qm6>\270\025\022;N/\020\277" + } + } + } +} +node { + name: "sequential/quant_dense/MatMul/ReadVariableOp" + op: "Identity" + input: "sequential/quant_dense/MatMul/ReadVariableOp/resource" + attr { + 
key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "sequential/quant_dense/MatMul/kquant/FakeQuantWithMinMaxVars/ReadVariableOp/resource" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -0.5634322166442871 + } + } + } +} +node { + name: "sequential/quant_dense/MatMul/kquant/FakeQuantWithMinMaxVars/ReadVariableOp" + op: "Identity" + input: "sequential/quant_dense/MatMul/kquant/FakeQuantWithMinMaxVars/ReadVariableOp/resource" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "sequential/quant_dense/MatMul/kquant/FakeQuantWithMinMaxVars/ReadVariableOp_1/resource" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.5968852043151855 + } + } + } +} +node { + name: "sequential/quant_dense/MatMul/kquant/FakeQuantWithMinMaxVars/ReadVariableOp_1" + op: "Identity" + input: "sequential/quant_dense/MatMul/kquant/FakeQuantWithMinMaxVars/ReadVariableOp_1/resource" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "sequential/quant_dense/MatMul/kquant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "sequential/quant_dense/MatMul/ReadVariableOp" + input: "sequential/quant_dense/MatMul/kquant/FakeQuantWithMinMaxVars/ReadVariableOp" + input: "sequential/quant_dense/MatMul/kquant/FakeQuantWithMinMaxVars/ReadVariableOp_1" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "sequential/quant_dense/MatMul/kquant/IdentityN" + op: "IdentityN" + input: "sequential/quant_dense/MatMul/kquant/FakeQuantWithMinMaxVars" + input: "sequential/quant_dense/MatMul/ReadVariableOp" + attr { + key: "T" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + } + } + } + attr { + key: "_gradient_op_type" + value { + s: "CustomGradient-10455" + } + } +} +node { + name: "sequential/quant_dense/MatMul" + op: "MatMul" + input: "quant_dense_input" + input: "sequential/quant_dense/MatMul/kquant/IdentityN" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "sequential/quant_dense/BiasAdd/ReadVariableOp/resource" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 16 + } + } + tensor_content: "\000\000\000\000L\020\341=\355\223\242\276\000\000\000\000\000\000\000\000\223&<>\206\234d\276\000\000\000\000\323%\241\276\331\305\234>\004\341\216>\3545e\276\032\363O\276\257:u\276\313\223\r>\000\000\000\000" + } + } + } +} +node { + name: "sequential/quant_dense/BiasAdd/ReadVariableOp" + op: "Identity" + input: "sequential/quant_dense/BiasAdd/ReadVariableOp/resource" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "sequential/quant_dense/BiasAdd" + op: "BiasAdd" + input: "sequential/quant_dense/MatMul" + input: "sequential/quant_dense/BiasAdd/ReadVariableOp" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } +} +node { + name: "sequential/quant_dense/Relu" + op: "Relu" + input: "sequential/quant_dense/BiasAdd" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: 
"sequential/quant_dense/oquant/FakeQuantWithMinMaxVars/ReadVariableOp/resource" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -0.0019266719464212656 + } + } + } +} +node { + name: "sequential/quant_dense/oquant/FakeQuantWithMinMaxVars/ReadVariableOp" + op: "Identity" + input: "sequential/quant_dense/oquant/FakeQuantWithMinMaxVars/ReadVariableOp/resource" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "sequential/quant_dense/oquant/FakeQuantWithMinMaxVars/ReadVariableOp_1/resource" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 3.249699354171753 + } + } + } +} +node { + name: "sequential/quant_dense/oquant/FakeQuantWithMinMaxVars/ReadVariableOp_1" + op: "Identity" + input: "sequential/quant_dense/oquant/FakeQuantWithMinMaxVars/ReadVariableOp_1/resource" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "sequential/quant_dense/oquant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "sequential/quant_dense/Relu" + input: "sequential/quant_dense/oquant/FakeQuantWithMinMaxVars/ReadVariableOp" + input: "sequential/quant_dense/oquant/FakeQuantWithMinMaxVars/ReadVariableOp_1" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "sequential/quant_dense/oquant/IdentityN" + op: "IdentityN" + input: "sequential/quant_dense/oquant/FakeQuantWithMinMaxVars" + input: "sequential/quant_dense/Relu" + attr { + key: "T" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + } + } + } + attr { + key: "_gradient_op_type" + value { + s: "CustomGradient-10477" + } + } +} +node { + name: "sequential/quant_dense_1/iquant/FakeQuantWithMinMaxVars/ReadVariableOp/resource" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -0.0019266719464212656 + } + } + } +} +node { + name: "sequential/quant_dense_1/iquant/FakeQuantWithMinMaxVars/ReadVariableOp" + op: "Identity" + input: "sequential/quant_dense_1/iquant/FakeQuantWithMinMaxVars/ReadVariableOp/resource" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "sequential/quant_dense_1/iquant/FakeQuantWithMinMaxVars/ReadVariableOp_1/resource" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 3.1626083850860596 + } + } + } +} +node { + name: "sequential/quant_dense_1/iquant/FakeQuantWithMinMaxVars/ReadVariableOp_1" + op: "Identity" + input: "sequential/quant_dense_1/iquant/FakeQuantWithMinMaxVars/ReadVariableOp_1/resource" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "sequential/quant_dense_1/iquant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "sequential/quant_dense/oquant/IdentityN" + input: "sequential/quant_dense_1/iquant/FakeQuantWithMinMaxVars/ReadVariableOp" + input: "sequential/quant_dense_1/iquant/FakeQuantWithMinMaxVars/ReadVariableOp_1" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "sequential/quant_dense_1/iquant/IdentityN" + op: "IdentityN" + input: 
"sequential/quant_dense_1/iquant/FakeQuantWithMinMaxVars" + input: "sequential/quant_dense/oquant/IdentityN" + attr { + key: "T" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + } + } + } + attr { + key: "_gradient_op_type" + value { + s: "CustomGradient-10494" + } + } +} +node { + name: "sequential/quant_dense_1/MatMul/ReadVariableOp/resource" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 16 + } + dim { + size: 16 + } + } + tensor_content: "\356x\007<\273\351H\275\2506\242\275\261t\332>\"O\'>m\"\010\276(\311b\275\301\233\262=d\327\031\275V\030`>1U\212\276\325\353,=\321\375+=\305\016-=IQC\276\272N\255\276\342\\\241=\253K\230>\221\247\220\274\026[\220=\301W9\276\3041\311>\200\013\255\276@\331\307>\250\246\320=e\305A\276\312\211\200>\252\220x\276?\306\305>jC\221>_h\304>\350\2701>\213&\014\277\327a\251>\010\321\365=\022\346\357>_P\326>\024@\260\276\010\002\203\275%\013;\276\300\270s\2762~\206\2769\341\356<;\364\023\276\'=\225\276\347\327\247\275\037\317\326\276\2379\243>\241+\306\276\252\226\007\276\232\273\010= \001s>h.\010=x\rK=\231\255\315\276\001\244>>_\227\304\276\030\255\255\276\t\013\014>\263\031i\276\312\312\204>\262\331\244=\367\331R\276\020h\005>\361\260\322a?\323>\t\236\304\275\240\215\240\275Y\034V>\200\236=<\232\336Y\276\274_\263\276\240\004\334\276\232\003\252>\264\243\273>\254Q\201>_\271\263\276a1\305>\237?\331>\244\031\275>\365K\351\276\212\302\234>\276{\217\275{\211{\276.\016\271\276\337\344\255\276\230\2347\2757O(>puO=\024\177\207\276v:\255>%.\230>\237\211\341=\327\316\037\274\302\205I>\020\234\201\276\027\356(\276J\320*\277\014\226\240=\222\312\323\276J\235\253=\263\201\221>\030\220\027\276\375g\"\277\227E\326\276\206\267{>\253\271\374\276!\235\203>\264\215\342=g\235\267\2767\211\306>x\323i>\244TJ>\306\216]\275\310\230\004\276\340\000B\274\351\324a>$[\266=\267\220\256\275\241n >\003s\235>\340\251\225\276\335jX=N\"@\276N;\271\276Sl6=\234S\007>~\345\025\277\013\020\362\275\377\333 \276@\227\325=\177\375\215=\335\357\332=t\217\027\276D\337\326=}\211\363\276\227\"\322>r\245\267=]\007x=4\372e>\252\235\373\276\030\276\221\275\2763\000\276\202\007\301>H\037Y>\307;\006?QT\241=NY\\<\014\254@>P\031\217\010\327+\276\212\240\305\276\371x7>$\350\037?I\332\262\276\361\307\031\276\275o\206>\355\217\016\276 \036\007>yQ\221=\372\204\207>#\301\214>\344\376\302\275y\024\261\2768\304+\276\260\n \276(\026\254>\300*\034=j\2142=[\023\233>\301\370\304\2766<0=\321\025\237>\264C\220\2765\301\330\276\274\027\252\276\n\331U\276\355\363\302>\214i\305\276\367`\273>\244\352\241\276c\030C\276u\033\251\275r\314\036>\033\352h\275\361\372\374\276\370\356\242\274O\305\030\277\360\264\221>\244 \'\276\223j\006>\334\320E>\274\223\031\277\275\210\326>\330\301o=9\273>>\324\003\314\275\234\204\215\276\321\253|\276\313v\275\276d\256\270>r,\013\27765J=\313s\'\276\035\363{\276\"3\251\275\007\267d<\266\261\256\2766\205g\276~\275\331>\272\317\242>\237\305\034>\312\1770\2769\356\024\277\255|Y\276\374V\271=\320b\177\276\014\023\345\2766\353\226\273\333\033\373\275.[\264\276\tx@\276Dx\273>Q\327\233>\246\265\353>\021\214\315\276\000K\356\272d\207\037\276\331Z\321\276\\>\001?\202\343\031?\000>^\2727\375V\276\266\006[=\367\377\313>;vB\275F\212\022?D\253\246\276\314\207\210>\255\222\211\276\230V\223\273z,\316\276\325D\206\276CR\260> \260-\275\273-*\276i\032d>\266\316\003>\300\322\205\276\232\205\322\276\036\267\205>\372Y\342=r\221k\276(\003k>" + } + } + } +} +node { + name: 
"sequential/quant_dense_1/MatMul/ReadVariableOp" + op: "Identity" + input: "sequential/quant_dense_1/MatMul/ReadVariableOp/resource" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "sequential/quant_dense_1/MatMul/kquant/FakeQuantWithMinMaxVars/ReadVariableOp/resource" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -0.6485296487808228 + } + } + } +} +node { + name: "sequential/quant_dense_1/MatMul/kquant/FakeQuantWithMinMaxVars/ReadVariableOp" + op: "Identity" + input: "sequential/quant_dense_1/MatMul/kquant/FakeQuantWithMinMaxVars/ReadVariableOp/resource" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "sequential/quant_dense_1/MatMul/kquant/FakeQuantWithMinMaxVars/ReadVariableOp_1/resource" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.6135242581367493 + } + } + } +} +node { + name: "sequential/quant_dense_1/MatMul/kquant/FakeQuantWithMinMaxVars/ReadVariableOp_1" + op: "Identity" + input: "sequential/quant_dense_1/MatMul/kquant/FakeQuantWithMinMaxVars/ReadVariableOp_1/resource" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "sequential/quant_dense_1/MatMul/kquant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "sequential/quant_dense_1/MatMul/ReadVariableOp" + input: "sequential/quant_dense_1/MatMul/kquant/FakeQuantWithMinMaxVars/ReadVariableOp" + input: "sequential/quant_dense_1/MatMul/kquant/FakeQuantWithMinMaxVars/ReadVariableOp_1" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "sequential/quant_dense_1/MatMul/kquant/IdentityN" + op: "IdentityN" + input: "sequential/quant_dense_1/MatMul/kquant/FakeQuantWithMinMaxVars" + input: "sequential/quant_dense_1/MatMul/ReadVariableOp" + attr { + key: "T" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + } + } + } + attr { + key: "_gradient_op_type" + value { + s: "CustomGradient-10513" + } + } +} +node { + name: "sequential/quant_dense_1/MatMul" + op: "MatMul" + input: "sequential/quant_dense_1/iquant/IdentityN" + input: "sequential/quant_dense_1/MatMul/kquant/IdentityN" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "sequential/quant_dense_1/BiasAdd/ReadVariableOp/resource" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 16 + } + } + tensor_content: "Y\226;>\002#==\315\253\207>\032t\021\276\2510O\273\354\026\001\276\000\000\000\000\227\021M>\275D\213>\000\000\000\000\231\354)\276]@\237>\026\327E=p\007\003>\007\340c>\\\241\336\275" + } + } + } +} +node { + name: "sequential/quant_dense_1/BiasAdd/ReadVariableOp" + op: "Identity" + input: "sequential/quant_dense_1/BiasAdd/ReadVariableOp/resource" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "sequential/quant_dense_1/BiasAdd" + op: "BiasAdd" + input: "sequential/quant_dense_1/MatMul" + input: "sequential/quant_dense_1/BiasAdd/ReadVariableOp" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } +} +node { + name: 
"sequential/quant_dense_1/Relu" + op: "Relu" + input: "sequential/quant_dense_1/BiasAdd" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "sequential/quant_dense_1/oquant/FakeQuantWithMinMaxVars/ReadVariableOp/resource" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -0.0019266719464212656 + } + } + } +} +node { + name: "sequential/quant_dense_1/oquant/FakeQuantWithMinMaxVars/ReadVariableOp" + op: "Identity" + input: "sequential/quant_dense_1/oquant/FakeQuantWithMinMaxVars/ReadVariableOp/resource" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "sequential/quant_dense_1/oquant/FakeQuantWithMinMaxVars/ReadVariableOp_1/resource" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 2.7181646823883057 + } + } + } +} +node { + name: "sequential/quant_dense_1/oquant/FakeQuantWithMinMaxVars/ReadVariableOp_1" + op: "Identity" + input: "sequential/quant_dense_1/oquant/FakeQuantWithMinMaxVars/ReadVariableOp_1/resource" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "sequential/quant_dense_1/oquant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "sequential/quant_dense_1/Relu" + input: "sequential/quant_dense_1/oquant/FakeQuantWithMinMaxVars/ReadVariableOp" + input: "sequential/quant_dense_1/oquant/FakeQuantWithMinMaxVars/ReadVariableOp_1" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "sequential/quant_dense_1/oquant/IdentityN" + op: "IdentityN" + input: "sequential/quant_dense_1/oquant/FakeQuantWithMinMaxVars" + input: "sequential/quant_dense_1/Relu" + attr { + key: "T" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + } + } + } + attr { + key: "_gradient_op_type" + value { + s: "CustomGradient-10535" + } + } +} +node { + name: "sequential/quant_dense_2/MatMul/ReadVariableOp/resource" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 16 + } + dim { + size: 1 + } + } + tensor_content: "\217\017J?$\023*>\352\265\272\276\304|;\276\2178\277=\275\031\026\277J\017\304\276Iu!?h]\314\276A7;\276\204\221\031\275\323\2259\277\341K\304>Uzs?@)\337>\"\300n\276" + } + } + } +} +node { + name: "sequential/quant_dense_2/MatMul/ReadVariableOp" + op: "Identity" + input: "sequential/quant_dense_2/MatMul/ReadVariableOp/resource" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "sequential/quant_dense_2/MatMul/kquant/FakeQuantWithMinMaxVars/ReadVariableOp/resource" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -0.7017847895622253 + } + } + } +} +node { + name: "sequential/quant_dense_2/MatMul/kquant/FakeQuantWithMinMaxVars/ReadVariableOp" + op: "Identity" + input: "sequential/quant_dense_2/MatMul/kquant/FakeQuantWithMinMaxVars/ReadVariableOp/resource" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "sequential/quant_dense_2/MatMul/kquant/FakeQuantWithMinMaxVars/ReadVariableOp_1/resource" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: 
DT_FLOAT + tensor_shape { + } + float_val: 0.9041997790336609 + } + } + } +} +node { + name: "sequential/quant_dense_2/MatMul/kquant/FakeQuantWithMinMaxVars/ReadVariableOp_1" + op: "Identity" + input: "sequential/quant_dense_2/MatMul/kquant/FakeQuantWithMinMaxVars/ReadVariableOp_1/resource" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "sequential/quant_dense_2/MatMul/kquant/FakeQuantWithMinMaxVars" + op: "FakeQuantWithMinMaxVars" + input: "sequential/quant_dense_2/MatMul/ReadVariableOp" + input: "sequential/quant_dense_2/MatMul/kquant/FakeQuantWithMinMaxVars/ReadVariableOp" + input: "sequential/quant_dense_2/MatMul/kquant/FakeQuantWithMinMaxVars/ReadVariableOp_1" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "sequential/quant_dense_2/MatMul/kquant/IdentityN" + op: "IdentityN" + input: "sequential/quant_dense_2/MatMul/kquant/FakeQuantWithMinMaxVars" + input: "sequential/quant_dense_2/MatMul/ReadVariableOp" + attr { + key: "T" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + } + } + } + attr { + key: "_gradient_op_type" + value { + s: "CustomGradient-10554" + } + } +} +node { + name: "sequential/quant_dense_2/MatMul" + op: "MatMul" + input: "sequential/quant_dense_1/oquant/IdentityN" + input: "sequential/quant_dense_2/MatMul/kquant/IdentityN" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "sequential/quant_dense_2/BiasAdd/ReadVariableOp/resource" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + } + float_val: 0.04556306451559067 + } + } + } +} +node { + name: "sequential/quant_dense_2/BiasAdd/ReadVariableOp" + op: "Identity" + input: "sequential/quant_dense_2/BiasAdd/ReadVariableOp/resource" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "sequential/quant_dense_2/BiasAdd" + op: "BiasAdd" + input: "sequential/quant_dense_2/MatMul" + input: "sequential/quant_dense_2/BiasAdd/ReadVariableOp" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } +} +node { + name: "sequential/quant_dense_2/oquant/FakeQuantWithMinMaxVars/ReadVariableOp/resource" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: -0.8735198974609375 + } + } + } +} +node { + name: "sequential/quant_dense_2/oquant/FakeQuantWithMinMaxVars/ReadVariableOp" + op: "Identity" + input: "sequential/quant_dense_2/oquant/FakeQuantWithMinMaxVars/ReadVariableOp/resource" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "sequential/quant_dense_2/oquant/FakeQuantWithMinMaxVars/ReadVariableOp_1/resource" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0447778701782227 + } + } + } +} +node { + name: "sequential/quant_dense_2/oquant/FakeQuantWithMinMaxVars/ReadVariableOp_1" + op: "Identity" + input: "sequential/quant_dense_2/oquant/FakeQuantWithMinMaxVars/ReadVariableOp_1/resource" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "sequential/quant_dense_2/oquant/FakeQuantWithMinMaxVars" + 
op: "FakeQuantWithMinMaxVars" + input: "sequential/quant_dense_2/BiasAdd" + input: "sequential/quant_dense_2/oquant/FakeQuantWithMinMaxVars/ReadVariableOp" + input: "sequential/quant_dense_2/oquant/FakeQuantWithMinMaxVars/ReadVariableOp_1" + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "sequential/quant_dense_2/oquant/IdentityN" + op: "IdentityN" + input: "sequential/quant_dense_2/oquant/FakeQuantWithMinMaxVars" + input: "sequential/quant_dense_2/BiasAdd" + attr { + key: "T" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + } + } + } + attr { + key: "_gradient_op_type" + value { + s: "CustomGradient-10575" + } + } +} +node { + name: "sequential/output/FakeQuantWithMinMaxArgs" + op: "FakeQuantWithMinMaxArgs" + input: "sequential/quant_dense_2/oquant/IdentityN" + attr { + key: "max" + value { + f: 0.9921875 + } + } + attr { + key: "min" + value { + f: -1.0 + } + } + attr { + key: "narrow_range" + value { + b: false + } + } + attr { + key: "num_bits" + value { + i: 8 + } + } +} +node { + name: "Identity" + op: "Identity" + input: "sequential/output/FakeQuantWithMinMaxArgs" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +versions { + producer: 175 +} + +# CHECK: {{.*: warning:}} loc(fused["sequential/quant_dense/oquant/FakeQuantWithMinMaxVars"{{.*\): }}quantizer's output has another quantizer +# CHECK: {{.*: warning:}} loc(fused["sequential/quant_dense_2/oquant/FakeQuantWithMinMaxVars"{{.*\): }}quantizer's output has another quantizer + +# RESULT1: name: "Identity" +# RESULT1-NEXT: quantization: +# RESULT1-NEXT: scale: [ 0.007523 ], +# RESULT1-NEXT: zero_point: [ 116 ] + +# TODO Actually RESULT1 represents in incomplete implementation +# Currently TF2.2.0-rc3 all but the first fake_quant in +# sequence are dropped. A correct transalation would retain the second as a requantization +# op. 
+ +# CORRECT1:name: "Identity" +# CORRECT1-NEXT: quantization: +# CORRECT1-NEXT: scale: [ 0.007813 ], +# CORRECT1-NEXT: zero_point: [ 128 ] + diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt new file mode 100644 index 00000000000..096033e37cb --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt @@ -0,0 +1,101 @@ +# RUN: tf_tfl_translate -tf-input-arrays=Placeholder,Placeholder_1 -tf-input-shapes=2,5,3:3,7 -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-output-arrays=MatMul -output-mlir %s -o - 2>&1 | FileCheck %s + +node { + name: "Placeholder" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 2 + } + dim { + size: 5 + } + dim { + size: 3 + } + } + } + } +} +node { + name: "Placeholder_1" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 3 + } + dim { + size: 7 + } + } + } + } +} +node { + name: "MatMul" + op: "BatchMatMulV2" + input: "Placeholder" + input: "Placeholder_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "adj_x" + value { + b: false + } + } + attr { + key: "adj_y" + value { + b: false + } + } +} +versions { + producer: 175 +} + +# CHECK: func @main(%[[VAL_0:.*]]: tensor<2x5x3xf32>, %[[VAL_1:.*]]: tensor<3x7xf32>) -> tensor<2x5x7xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "Placeholder,Placeholder_1", outputs = "MatMul"}} { +# CHECK: %[[VAL_2:.*]] = constant dense<[1, 0]> : tensor<2xi32> +# CHECK: %[[VAL_3:.*]] = constant dense<[5, 3]> : tensor<2xi32> +# CHECK: %[[VAL_4:.*]] = constant dense<[3, 7]> : tensor<2xi32> +# CHECK: %[[VAL_5:.*]] = constant unit +# CHECK: %[[VAL_6:.*]] = constant dense<[1, 0, 0]> : tensor<3xi32> +# CHECK: %[[VAL_7:.*]] = constant dense<[1, 5, 3]> : tensor<3xi32> +# CHECK: %[[VAL_8:.*]] = constant dense<0> : tensor<3xi32> +# CHECK: %[[VAL_9:.*]] = constant dense<[1, 3, 7]> : tensor<3xi32> +# CHECK: %[[VAL_10:.*]] = "tfl.slice"(%[[VAL_0]], %[[VAL_8]], %[[VAL_7]]) : (tensor<2x5x3xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x5x3xf32> +# CHECK: %[[VAL_11:.*]] = "tfl.reshape"(%[[VAL_10]], %[[VAL_3]]) : (tensor<1x5x3xf32>, tensor<2xi32>) -> tensor<5x3xf32> +# CHECK: %[[VAL_12:.*]] = "tfl.slice"(%[[VAL_0]], %[[VAL_6]], %[[VAL_7]]) : (tensor<2x5x3xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x5x3xf32> +# CHECK: %[[VAL_13:.*]] = "tfl.reshape"(%[[VAL_12]], %[[VAL_3]]) : (tensor<1x5x3xf32>, tensor<2xi32>) -> tensor<5x3xf32> +# CHECK: %[[VAL_14:.*]] = "tfl.reshape"(%[[VAL_1]], %[[VAL_9]]) : (tensor<3x7xf32>, tensor<3xi32>) -> tensor<1x3x7xf32> +# CHECK: %[[VAL_15:.*]] = "tfl.slice"(%[[VAL_14]], %[[VAL_8]], %[[VAL_9]]) : (tensor<1x3x7xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x3x7xf32> +# CHECK: %[[VAL_16:.*]] = "tfl.reshape"(%[[VAL_15]], %[[VAL_4]]) : (tensor<1x3x7xf32>, tensor<2xi32>) -> tensor<3x7xf32> +# CHECK: %[[VAL_17:.*]] = "tfl.transpose"(%[[VAL_16]], %[[VAL_2]]) : (tensor<3x7xf32>, tensor<2xi32>) -> tensor<7x3xf32> +# CHECK: %[[VAL_18:.*]] = "tfl.fully_connected"(%[[VAL_11]], %[[VAL_17]], %[[VAL_5]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<5x3xf32>, tensor<7x3xf32>, none) -> tensor<5x7xf32> +# CHECK: %[[VAL_19:.*]] = "tfl.fully_connected"(%[[VAL_13]], %[[VAL_17]], %[[VAL_5]]) {fused_activation_function = "NONE", keep_num_dims = 
false, weights_format = "DEFAULT"} : (tensor<5x3xf32>, tensor<7x3xf32>, none) -> tensor<5x7xf32> +# CHECK: %[[VAL_20:.*]] = "tfl.pack"(%[[VAL_18]], %[[VAL_19]]) {axis = 0 : i32, values_count = 2 : i32} : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<2x5x7xf32> +# CHECK: return %[[VAL_20]] : tensor<2x5x7xf32> +# CHECK: } diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir index 0dd8ddc4c91..d793ea2d62f 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir @@ -1,15 +1,15 @@ // RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s // Ensure lstm roundtrip exactly -func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4x4xf32>, %arg10: tensor<4x4xf32>, %arg11: tensor<4x4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<1x4xf32>, %arg14: tensor<1x4xf32>, %arg15: tensor<1x4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<1x4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>) -> tensor<1x4xf32> { +func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>) -> tensor<1x4xf32> { %cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const") %cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const") - %24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> + %24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, 
tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> return %24 : tensor<1x4xf32> // CHECK-LABEL: main // seperate lines since there is no region for this op. third_party/tensorflow/compiler/mlir/lite/ir/tfl_ops.td: 3252 // CHECK: %[[RES0:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg22, %arg23, %arg18, %arg19, %arg20, %arg21) ( { -// CHECK: }) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> +// CHECK: }) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> // CHECK: return %[[RES0]] } diff --git a/tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir b/tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir new file mode 100644 index 00000000000..f08ac0e1027 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir @@ -0,0 +1,14 @@ +// RUN: tf-opt -tfl-prepare-composite-funcs-tf -tfl-fuse-tftext=true %s -split-input-file | FileCheck %s --dump-input-on-failure +module { + + func @_whitespace_func(%arg0: tensor<1x!tf.string>) -> (tensor, tensor) attributes {tf._GrapplerSpecializedFunc = true, tf._input_shapes = [#tf.shape<1>], tf.api_implements = "tftext:WhitespaceTokenizer", tf.signature.is_stateful} { + %0 = "tf.op1"(%arg0) : (tensor<1x!tf.string>) -> (tensor) + %1 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %2:2 = "tf.op2"(%arg0, %1) : (tensor<1x!tf.string>, tensor) -> (tensor, tensor) + return %2#0, %2#1 : tensor, tensor + } + + // CHECK: func @_whitespace_func(%arg0: tensor<1x!tf.string>) -> (tensor, tensor) attributes {tf._GrapplerSpecializedFunc = true, tf._input_shapes = [#tf.shape<1>], tf.api_implements = "tftext:WhitespaceTokenizer", tf.signature.is_stateful} { + // CHECK: "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor<1x!tf.string>) -> (tensor, tensor) + // CHECK: return %0#0, %0#1 : tensor, tensor +} diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index 15b6bf56b7a..15c73d2db2c 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -1048,6 +1048,15 @@ func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2 // CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32> } 
+func @concatv2I64Axis(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> { + %0 = "tf.Const"() { value = dense<-1> : tensor } : () -> tensor + %1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor) -> tensor<2x3xi32> + return %1 : tensor<2x3xi32> + +// CHECK-LABEL: concatv2I64Axis +// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32> +} + func @resize_with_bilinear(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor { %0 = "tf.ResizeBilinear"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor return %0 : tensor diff --git a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir index 9b1eeab3d7c..a7fb5b1666e 100644 --- a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir +++ b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir @@ -292,7 +292,7 @@ func @tensorlistResize(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: ten // CHECK: [[SIZE_DIFF:%.*]] = "tf.Sub"([[SIZE]], [[INPUT_SIZE]]) : (tensor, tensor) -> tensor // CHECK: [[DIFF_RES:%.*]] = "tf.Greater"([[SIZE_DIFF]], [[ZERO]]) : (tensor, tensor) -> tensor // CHECK: [[SHAPE_1:%.*]] = "tf.Shape"([[INPUT]]) : (tensor<3x10xf32>) -> tensor -// CHECK: [[RESULT:%.*]] = "tf.If"([[DIFF_RES]], [[INPUT]], [[SHAPE_1]], [[SIZE_DIFF]], [[SIZE]]) {else_branch = @cond_false, is_stateless = true, output_shapes = [], then_branch = @cond_true} : (tensor, tensor<3x10xf32>, tensor, tensor, tensor) -> tensor +// CHECK: [[RESULT:%.*]] = "tf.If"([[DIFF_RES]], [[INPUT]], [[SHAPE_1]], [[SIZE_DIFF]], [[SIZE]]) {else_branch = @cond_false, is_stateless = true, then_branch = @cond_true} : (tensor, tensor<3x10xf32>, tensor, tensor, tensor) -> tensor // CHECK: return [[RESULT]] : tensor } diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/custom_op_with_tflite_op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/custom_op_with_tflite_op.mlir index 1b46fa3d0e5..320f869ac4c 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/custom_op_with_tflite_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/custom_op_with_tflite_op.mlir @@ -65,7 +65,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: opcode_index: 1, // CHECK-NEXT: inputs: [ 2, 1 ], // CHECK-NEXT: outputs: [ 3 ], -// CHECK-NEXT: custom_options: [ 105, 110, 116, 95, 97, 116, 116, 114, 0, 102, 117, 115, 101, 100, 95, 97, 99, 116, 105, 118, 97, 116, 105, 111, 110, 95, 102, 117, 110, 99, 116, 105, 111, 110, 0, 4, 82, 69, 76, 85, 0, 2, 33, 43, 2, 1, 2, 11, 2, 20, 4, 4, 36, 1 ] +// CHECK-NEXT: custom_options: [ 102, 117, 115, 101, 100, 95, 97, 99, 116, 105, 118, 97, 116, 105, 111, 110, 95, 102, 117, 110, 99, 116, 105, 111, 110, 0, 4, 82, 69, 76, 85, 0, 105, 110, 116, 95, 97, 116, 116, 114, 0, 2, 42, 11, 2, 1, 2, 20, 2, 20, 4, 4, 36, 1 ] // CHECK-NEXT: }, { // CHECK-NEXT: opcode_index: 2, // CHECK-NEXT: inputs: [ 3 ], diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir index e278572cd1e..ef78f993cc4 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir @@ -1,6 +1,6 @@ // RUN: flatbuffer_translate 
-mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s -func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> { +func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { @@ -72,21 +72,21 @@ func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4, 4 ], +// CHECK-NEXT: shape: [ 4 ], // CHECK-NEXT: buffer: 10, // CHECK-NEXT: name: "arg9", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4, 4 ], +// CHECK-NEXT: shape: [ 4 ], // CHECK-NEXT: buffer: 11, // CHECK-NEXT: name: "arg10", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4, 4 ], +// CHECK-NEXT: shape: [ 4 ], // CHECK-NEXT: buffer: 12, // CHECK-NEXT: name: "arg11", // CHECK-NEXT: quantization: { @@ -100,21 +100,21 @@ func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 1, 4 ], +// CHECK-NEXT: shape: [ 4 ], // CHECK-NEXT: buffer: 14, // CHECK-NEXT: name: "arg13", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 1, 4 ], +// CHECK-NEXT: shape: [ 4 ], // CHECK-NEXT: buffer: 15, // CHECK-NEXT: name: "arg14", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 1, 4 ], +// CHECK-NEXT: shape: [ 4 ], // CHECK-NEXT: buffer: 16, // CHECK-NEXT: name: "arg15", // CHECK-NEXT: quantization: { @@ -128,7 +128,7 @@ func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 1, 4 ], +// CHECK-NEXT: shape: [ 4 ], // CHECK-NEXT: buffer: 18, // CHECK-NEXT: name: "arg17", // CHECK-NEXT: quantization: { @@ -261,9 +261,9 @@ func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t // CHECK-EMPTY: -^bb0(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4x4xf32>, %arg10: tensor<4x4xf32>, %arg11: tensor<4x4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<1x4xf32>, %arg14: tensor<1x4xf32>, %arg15: tensor<1x4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<1x4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>): +^bb0(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: 
tensor<4x4xf32>, %arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>): %cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const") %cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const") - %24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> + %24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> return %24 : tensor<1x4xf32> } diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir index 8e579421b0b..d9bba58b7d7 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir @@ -1,6 +1,6 @@ // RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s -func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -> tensor<4 x f32> { +func @main(tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { @@ -9,63 +9,63 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, 
tensor<4 x f32>, t // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 1, // CHECK-NEXT: name: "arg0", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 2, // CHECK-NEXT: name: "arg1", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 3, // CHECK-NEXT: name: "arg2", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 4, // CHECK-NEXT: name: "arg3", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 5, // CHECK-NEXT: name: "arg4", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 6, // CHECK-NEXT: name: "arg5", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 7, // CHECK-NEXT: name: "arg6", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 8, // CHECK-NEXT: name: "arg7", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 9, // CHECK-NEXT: name: "arg8", // CHECK-NEXT: quantization: { @@ -121,63 +121,63 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 17, // CHECK-NEXT: name: "arg16", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 18, // CHECK-NEXT: name: "arg17", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 19, // CHECK-NEXT: name: "arg18", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 20, // CHECK-NEXT: name: "arg19", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 21, // CHECK-NEXT: name: "arg20", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 22, // CHECK-NEXT: name: "arg21", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: name: "Const", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: }, // CHECK-NEXT: is_variable: true // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: name: "Const1", // 
CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: }, // CHECK-NEXT: is_variable: true // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 25, // CHECK-NEXT: name: "tfl.unidirectional_sequence_lstm", // CHECK-NEXT: quantization: { @@ -244,9 +244,9 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t // CHECK-NEXT: }, { // CHECK-EMPTY: // CHECK-NEXT: }, { -// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] // CHECK-NEXT: }, { -// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] // CHECK-NEXT: }, { // CHECK-EMPTY: // CHECK-NEXT: }, { @@ -259,9 +259,9 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t // CHECK-NEXT: } // CHECK-EMPTY: -^bb0(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32>, %arg3: tensor<4 x f32>, %arg4: tensor<4 x f32>, %arg5: tensor<4 x f32>, %arg6: tensor<4 x f32>, %arg7: tensor<4 x f32>, %arg8: tensor<4 x f32>, %arg9: tensor<4 x f32>, %arg10: tensor<4 x f32>, %arg11: tensor<4 x f32>, %arg12: tensor<4 x f32>, %arg13: tensor<4 x f32>, %arg14: tensor<4 x f32>, %arg15: tensor<4 x f32>, %arg16: tensor<4 x f32>, %arg17: tensor<4 x f32>, %arg18: tensor<4 x f32>, %arg19: tensor<4 x f32>, %arg20: tensor<4 x f32>, %arg21: tensor<4 x f32>): - %0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") - %1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") - %2 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %0, %1, %arg18, %arg19, %arg20, %arg21) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %2 : tensor<4xf32> +^bb0(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>, %arg12: tensor<4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<4x4xf32>, %arg18: tensor<4x4xf32>, %arg19: tensor<4x4xf32>, %arg20: tensor<4x4xf32>, %arg21: tensor<4x4xf32>): + %0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> loc("Const") + %1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> loc("Const") + %2 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, 
%arg13, %arg14, %arg15, %arg16, %arg17, %0, %1, %arg18, %arg19, %arg20, %arg21) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + return %2 : tensor<4x4xf32> } diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_rnn.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_rnn.mlir index 7ba24bd5c51..f2b99bcd0df 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_rnn.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_rnn.mlir @@ -37,7 +37,7 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) - // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: name: "Const", // CHECK-NEXT: quantization: { // CHECK-EMPTY: @@ -76,7 +76,7 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) - // CHECK-NEXT: }, { // CHECK-EMPTY: // CHECK-NEXT: }, { -// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] // CHECK-NEXT: }, { // CHECK-EMPTY: // CHECK-NEXT: }, { @@ -90,7 +90,7 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) - // CHECK-EMPTY: ^bb0(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32>, %arg3: tensor<4 x f32>): - %0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") - %1 = "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %0) {fused_activation_function = "TANH", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> loc("Const") + %1 = "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %0) {fused_activation_function = "TANH", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>) -> tensor<4xf32> return %1 : tensor<4xf32> } diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index b1d1d81af37..3451f28380b 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -190,9 +190,9 @@ func @testSquare(tensor) -> tensor { return %0 : tensor } -func @testQuantizedResizeNearestNeighbor(tensor>, tensor) -> tensor> { -^bb0(%arg0: tensor>, %arg1: tensor): - %0 = "tfl.resize_nearest_neighbor"(%arg0, %arg1) { align_corners = false, half_pixel_centers = false } : (tensor>, tensor) -> tensor> +func @testQuantizedResizeNearestNeighbor(tensor>, tensor) -> tensor> { +^bb0(%arg0: tensor>, %arg1: tensor): + %0 = "tfl.resize_nearest_neighbor"(%arg0, %arg1) { align_corners = false, half_pixel_centers = false } : (tensor>, tensor) -> tensor> return %0 : tensor> } @@ -573,7 +573,7 @@ func 
@testLogistic(tensor<1x2x3x4x5xf32>) -> tensor<1x2x3x4x5xf32> { // test invalid Logistic input func @testLogisticWithWrongInputType(tensor) -> tensor { ^bb0(%arg0: tensor): - // expected-error @+1 {{tfl.logistic' op operand #0 must be tensor of 32-bit float or QI8 type or QUI8 type or QI16 type or QUI16 type values}} + // expected-error @+1 {{'tfl.logistic' op operand #0 must be tensor of 32-bit float or QI8 type or QUI8 type or QI16 type or TFLite quint8 type values, but got 'tensor'}} %0 = "tfl.logistic"(%arg0): (tensor) -> tensor return %0#0 : tensor } @@ -581,36 +581,36 @@ func @testLogisticWithWrongInputType(tensor) -> tensor { // ----- // CHECK-LABEL: testUnidirectionalSequenceRnn -func @testUnidirectionalSequenceRnn(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> tensor { - // CHECK: "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", time_major = false} : (tensor, tensor, tensor, tensor, tensor) -> tensor - %0 = "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", time_major = false} : (tensor, tensor, tensor, tensor, tensor) -> tensor +func @testUnidirectionalSequenceRnn(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> tensor { + // CHECK: "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", time_major = false} : (tensor, tensor, tensor, tensor, tensor) -> tensor + %0 = "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", time_major = false} : (tensor, tensor, tensor, tensor, tensor) -> tensor return %0 : tensor } // ----- // CHECK-LABEL: testUnidirectionalSequenceLstmWithoutProjection -func @testUnidirectionalSequenceLstmWithoutProjection(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: none, %arg17: none, %arg18: tensor, %arg19: tensor, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor { - // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, none, none, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor - %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, none, none, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor +func @testUnidirectionalSequenceLstmWithoutProjection(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: none, %arg17: none, %arg18: tensor, %arg19: 
tensor, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor { + // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, none, none, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor + %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, none, none, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor return %0 : tensor } // ----- // CHECK-LABEL: testUnidirectionalSequenceLstm -func @testUnidirectionalSequenceLstm(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: tensor, %arg17: tensor, %arg18: tensor, %arg19: tensor, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor { - // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor - %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor +func @testUnidirectionalSequenceLstm(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: tensor, %arg17: tensor, %arg18: tensor, %arg19: tensor, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor { + // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor + %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, 
%arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor return %0 : tensor } // ----- // CHECK-LABEL: testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr -func @testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr(%arg0: tensor, %arg1: none, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: tensor, %arg17: tensor, %arg18: tensor, %arg19: tensor, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor { - // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor, none, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor - %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor, none, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor +func @testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr(%arg0: tensor, %arg1: none, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: tensor, %arg17: tensor, %arg18: tensor, %arg19: tensor, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor { + // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor, none, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor + %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor, none, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor return %0 : tensor } @@ -663,10 +663,10 @@ func @testLstmQuantizedType(%arg0: 
tensor<1x528x!quant.uniform, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: tensor, %arg17: tensor, %arg18: tensor, %arg19: tensor, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor { +func @testLstm(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: tensor, %arg17: tensor, %arg18: tensor, %arg19: tensor, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor { // CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) - // CHECK-NEXT: {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor - %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor + // CHECK-NEXT: {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor + %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor return %0 : tensor } @@ -689,10 +689,10 @@ func @testQuantizedBasicLstm(%arg0: tensor<1x384x!quant.uniform, %arg1: none, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: tensor, %arg17: tensor, %arg18: tensor, %arg19: tensor, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor { +func @testLstmWithNoneTypeAndOverrideAttr(%arg0: tensor, %arg1: none, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: tensor, %arg17: tensor, %arg18: tensor, %arg19: tensor, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor { // CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, 
%arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) - // CHECK-NEXT: {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor, none, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor - %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor, none, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor + // CHECK-NEXT: {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor, none, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor + %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor, none, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor return %0 : tensor } @@ -707,11 +707,11 @@ func @testLstmWithInvalidNoneType(%arg0: tensor, %arg1: tensor // ----- -// test invalid input dimension, the first input operand for lstm op should be at least 2D tensor. +// test invalid input dimension, the third input operand for lstm op should be 2-D tensor. 
func @testLstmWithInvalidInputDimension(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32>, %arg3: tensor<4 x f32>, %arg4: tensor<4 x f32>, %arg5: tensor<4 x f32>, %arg6: tensor<4 x f32>, %arg7: tensor<4 x f32>, %arg8: tensor<4 x f32>, %arg9: tensor<4 x f32>, %arg10: tensor<4 x f32>, %arg11: tensor<4 x f32>, %arg12: tensor<4 x f32>, %arg13: tensor<4 x f32>, %arg14: tensor<4 x f32>, %arg15: tensor<4 x f32>, %arg16: tensor<4 x f32>, %arg17: tensor<4 x f32>, %arg18: tensor<4 x f32>, %arg19: tensor<4 x f32>, %arg20: tensor<4 x f32>, %arg21: tensor<4 x f32>) -> tensor<4 x f32> { %cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") %cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") - // expected-error @+1 {{'tfl.lstm' op the first input operand should have more than 2 dimensions.}} + // expected-error @+1 {{'tfl.lstm' op failed to verify that operand 2 is 2-D}} %24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %24 : tensor<4xf32> @@ -720,22 +720,22 @@ func @testLstmWithInvalidInputDimension(%arg0: tensor<4 x f32>, %arg1: tensor<4 // ----- // 'input_to_output_weights' input for lstm op has unmatched rank with `input`. 
-func @testLstmWithInvalidInputsRankMatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4x2xf32>, %arg2: tensor<4x2xf32>, %arg3: tensor<4x2xf32>, %arg4: tensor<4x2xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4x4xf32>, %arg10: tensor<4x4xf32>, %arg11: tensor<4x4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<1x4xf32>, %arg14: tensor<1x4xf32>, %arg15: tensor<1x4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<1x4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>) -> tensor<1x4xf32> { +func @testLstmWithInvalidInputsRankMatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4x2xf32>, %arg2: tensor<4x2xf32>, %arg3: tensor<4x2xf32>, %arg4: tensor<4x2xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>) -> tensor<1x4xf32> { %cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const") %cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const") // expected-error @+1 {{'tfl.lstm' op inputs don't match with the dimensions.}} - %24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> + %24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> return %24 : tensor<1x4xf32> } // ----- // Coefficient inputs of LSTM op don't match the dimension with input operand `input_to_output_weights`. 
-func @testLstmWithInvalidInputsRankMatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4x4xf32>, %arg10: tensor<4x4xf32>, %arg11: tensor<4x4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<1x4xf32>, %arg14: tensor<1x4xf32>, %arg15: tensor<1x4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<1x4xf32>, %arg18: tensor<3xf32>, %arg19: tensor<3xf32>, %arg20: tensor<3xf32>, %arg21: tensor<3xf32>) -> tensor<1x4xf32> { +func @testLstmWithInvalidInputsRankMatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<4xf32>, %arg18: tensor<3xf32>, %arg19: tensor<3xf32>, %arg20: tensor<3xf32>, %arg21: tensor<3xf32>) -> tensor<1x4xf32> { %cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const") %cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const") // expected-error @+1 {{'tfl.lstm' op coefficient inputs have more than 2 dimensions or don't match the dimension with input operand `input_to_output_weights`.}} - %24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<1x4xf32> + %24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<1x4xf32> return %24 : tensor<1x4xf32> } @@ -1201,7 +1201,7 @@ func @testResizeBilinear(%arg0 : tensor<1x100x100x3xf32>, %arg1 : tensor<4xi32>) // ----- func @testResizeBilinearInvalidOutputType(%arg0 : tensor<1x100x100x3xf32>, %arg1 : tensor<4xi32>) -> tensor { - // expected-error @+1 {{'tfl.resize_bilinear' op result #0 must be tensor of 32-bit float or QI8 type or QUI8 type values}} + // expected-error @+1 {{'tfl.resize_bilinear' op failed to verify that input and output must have same element type}} %0 = 
"tfl.resize_bilinear"(%arg0, %arg1) {align_corners = false} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor return %0 : tensor } @@ -1252,10 +1252,10 @@ func @testOneHot(%arg0: tensor<3xi32>, %arg1: tensor, %arg2: tensor, % // ----- -func @testOneHotWithInvalidOutputType(%arg0: tensor<3xi32>, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> tensor<*xi8> { - // expected-error @+1 {{'tfl.one_hot' op result #0 must be tensor of 32-bit float or 32-bit signless integer or 64-bit signless integer or 1-bit signless integer values}} - %0 = "tfl.one_hot"(%arg0, %arg1, %arg2, %arg3) {axis = -1 : i32} : (tensor<3xi32>, tensor, tensor, tensor) -> tensor<*xi8> - return %0 : tensor<*xi8> +func @testOneHotWithInvalidOutputType(%arg0: tensor<3xi32>, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> tensor<*xi16> { + // expected-error @+1 {{'tfl.one_hot' op result #0 must be tensor of 32-bit float or 32-bit signless integer or 64-bit signless integer or 1-bit signless integer or 8-bit signless integer or 8-bit unsigned integer values, but got 'tensor<*xi16>'}} + %0 = "tfl.one_hot"(%arg0, %arg1, %arg2, %arg3) {axis = -1 : i32} : (tensor<3xi32>, tensor, tensor, tensor) -> tensor<*xi16> + return %0 : tensor<*xi16> } // ----- @@ -1489,7 +1489,8 @@ func @testEmbeddingLookupValueAndResultElementTypeTraitFailed(%arg0 : tensor>) -> tensor<1x56x56x192x!quant.uniform> { +func @testWrongQuantizedLocalResponseNormalization(%arg0 : tensor<1x56x56x192x!quant.uniform>) -> tensor<1x56x56x192x!quant.uniform> { + // expected-error @+1 {{'tfl.local_response_normalization' op operand #0 must be tensor of 32-bit float values, but got 'tensor<1x56x56x192x!quant.uniform>'}} %0 = "tfl.local_response_normalization"(%arg0) {alpha = 9.99999974E-5 : f32, beta = 5.000000e-01 : f32, bias = 2.000000e+00 : f32, radius = 5 : i32} : (tensor<1x56x56x192x!quant.uniform>) -> tensor<1x56x56x192x!quant.uniform> return %0 : tensor<1x56x56x192x!quant.uniform> } @@ -1498,8 +1499,8 @@ func @testQuantizedLocalResponseNormalization(%arg0 : tensor<1x56x56x192x!quant. 
// CHECK-LABEL: testSvdf func @testSvdf(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> tensor { - // CHECK: "tfl.svdf"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", rank = 2 : i32} : (tensor, tensor, tensor, tensor, tensor) -> tensor - %0 = "tfl.svdf"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", rank = 2 : i32} : (tensor, tensor, tensor, tensor, tensor) -> tensor + // CHECK: "tfl.svdf"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "RELU", rank = 2 : i32} : (tensor, tensor, tensor, tensor, tensor) -> tensor + %0 = "tfl.svdf"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "RELU", rank = 2 : i32} : (tensor, tensor, tensor, tensor, tensor) -> tensor return %0 : tensor } @@ -1523,32 +1524,32 @@ func @testDepthToSpaceInvalidOutputType(%arg0: tensor<1x1x1x4xf32>) -> tensor<1x // ----- -func @testPReluWrongOutputRank(%arg0: tensor<10x10x10x10xf32>, %arg1: tensor<1x1x10xf32>) -> tensor<10x10x10xf32> { - // expected-error @+1 {{'input' and 'output' should have the same rank}} - %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<10x10x10x10xf32>, tensor<1x1x10xf32>) -> tensor<10x10x10xf32> - return %0 : tensor<10x10x10xf32> +func @testPReluWrongOutputRank(%arg0: tensor<10x10x10x10xf32>, %arg1: tensor<10x10x10x10xf32>) -> tensor<10x10xf32> { + // expected-error @+1 {{'tfl.prelu' op result type '10x10' not broadcast compatible with broadcasted operands's shapes '10x10x10x10'}} + %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<10x10x10x10xf32>, tensor<10x10x10x10xf32>) -> tensor<10x10xf32> + return %0 : tensor<10x10xf32> } // ----- func @testPReluWrongOutputShape(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<1x2x3x5xf32> { - // expected-error @+1 {{'input' and 'output' should have the same shape}} + // expected-error @+1 {{'tfl.prelu' op result type '1x2x3x5' not broadcast compatible with broadcasted operands's shapes '1x2x3x4'}} %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<1x2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<1x2x3x5xf32> return %0 : tensor<1x2x3x5xf32> } // ----- -func @testPReluWrongAlphaRank(%arg0: tensor<7x3x2x14xf32>, %arg1: tensor<2x7x3x2x14xf32>) -> tensor<7x3x2x14xf32> { +func @testPReluWrongAlphaRank(%arg0: tensor<7x3x2x14xf32>, %arg1: tensor<7x3x2x14xf32>) -> tensor<7x3x2x14xf32> { // expected-error @+1 {{'alpha' should have one less rank than 'input'.}} - %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<7x3x2x14xf32>, tensor<2x7x3x2x14xf32>) -> tensor<7x3x2x14xf32> + %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<7x3x2x14xf32>, tensor<7x3x2x14xf32>) -> tensor<7x3x2x14xf32> return %0 : tensor<7x3x2x14xf32> } // ----- func @testPReluInvalidBroadcast(%arg0: tensor<15x14x2x14xf32>, %arg1: tensor<1x1x3xf32>) -> tensor<15x14x2x14xf32> { - // expected-error @+1 {{'alpha' is not broadcastable at dimension 2.}} + // expected-error @+1 {{'tfl.prelu' op operands don't have broadcast-compatible shapes}} %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<15x14x2x14xf32>, tensor<1x1x3xf32>) -> tensor<15x14x2x14xf32> return %0 : tensor<15x14x2x14xf32> } diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 2815afd14b9..3f8257b54f0 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -439,6 +439,19 @@ func @NotReorderReshapeAddIfNotTailingDim(%arg0: tensor<40x40x1xf32>) -> tensor< // CHECK: return %[[rs2]] } +// CHECK-LABEL: @NotReorderReshapeAddIfHighDim +func 
@NotReorderReshapeAddIfHighDim(%arg0: tensor<1x1x1x1x30x96xf32>) -> tensor<1x30x96xf32> { + %cst = constant dense<2.0> : tensor + %shape = constant dense<[1, 30, 96]> : tensor<3xi32> + %1 = "tfl.reshape"(%arg0, %shape) : (tensor<1x1x1x1x30x96xf32>, tensor<3xi32>) -> tensor<1x30x96xf32> + %2 = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<1x30x96xf32>, tensor) -> tensor<1x30x96xf32> + return %2 : tensor<1x30x96xf32> + + // CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0 + // CHECK: %[[rs2:.*]] = "tfl.add"(%[[rs1]] + // CHECK: return %[[rs2]] +} + // CHECK-LABEL: @ReorderElementwiseValueOpAndMoveOp func @ReorderElementwiseValueOpAndMoveOp(%arg0: tensor<40x40x1xf32>) -> tensor<40x40xf32> { %shape = constant dense<[40, 40]> : tensor<2xi32> diff --git a/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir index 5377c4fdb98..6573a2f1c36 100644 --- a/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir @@ -19,6 +19,16 @@ func @RemoveUnused(%arg0: tensor<4xf32>, %arg1: tensor) -> (tensor<2xf32>,t // CHECK-NEXT: return %[[split]]#0, %[[split]]#1 } +// CHECK-LABEL: RemoveTrival +func @RemoveTrival(%arg0: tensor<384x512x!quant.uniform>, %arg1: tensor<128x512x!quant.uniform:f32, 1.0>>, %arg2: none) -> tensor<384x128x!quant.uniform> { + %1 = "tfl.fully_connected"(%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<384x512x!quant.uniform>, tensor<128x512x!quant.uniform:f32, 1.0>>, none) -> tensor<384x128x!quant.uniform> + %2 = "tfl.quantize"(%1) {qtype = tensor<384x128x!quant.uniform>} : (tensor<384x128x!quant.uniform>) -> tensor<384x128x!quant.uniform> + return %2 : tensor<384x128x!quant.uniform> + +// CHECK-NEXT: %[[fc:.*]] = "tfl.fully_connected"{{.*}} -> tensor<384x128x!quant.uniform> +// CHECK-NEXT: return %[[fc]] +} + func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> { %cst = constant dense<[1, 1001]> : tensor<2xi32> %0 = "tfl.quantize"(%arg0) {qtype = tensor<1x224x224x3x!quant.uniform>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform> diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir index f6054f3d65d..e1f496b91f4 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir @@ -63,7 +63,7 @@ func @prepareAdd(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { return %add : tensor<2x2xf32> // CHECK: %[[cst:.*]] = constant dense<[{{\[}}0.000000e+00, 1.000000e+00], [2.000000e+00, 2.550000e+02]]> -// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<2x2x!quant.uniform>} +// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<2x2x!quant.uniform>, volatile} // CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) // CHECK: %[[add:.*]] = tfl.add %arg0, %[[dq]] // CHECK: return %[[add]] @@ -83,7 +83,7 @@ func @prepareConv2DSplat(%arg0: tensor<1x5x5x3xf32>) -> tensor<1x5x5x3xf32> { // CHECK: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[dq]] // PerTensor: %[[cst:.*]] = constant dense<1.270000e+02> : tensor<3x3x3x3xf32> -// PerTensor: %[[q:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<3x3x3x3x!quant.uniform:f32, 1.000000e+00>>} +// PerTensor: %[[q:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<3x3x3x3x!quant.uniform:f32, 1.000000e+00>>, volatile} // PerTensor: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) 
// PerTensor: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[dq]] } @@ -97,7 +97,7 @@ func @prepareConv2D(%arg0: tensor<1x5x5x1xf32>) -> tensor<1x5x5x3xf32> { // CHECK: %[[cst:.*]] = constant dense<[{{\[\[\[}}0.000000e+00]]], [{{\[\[}}1.270000e+02]]], [{{\[\[}}-1.270000e+02]]]]> // CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<3x1x1x1x!quant.uniform:f32:0, -// CHECK-SAME: {0.0078740157480314959,1.000000e+00,1.000000e+00}>>} +// CHECK-SAME: {0.0078740157480314959,1.000000e+00,1.000000e+00}>>, volatile} // CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) // CHECK: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[dq]] @@ -134,12 +134,12 @@ func @QuantizeFullyConnected(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112 return %fc : tensor<1x112x112x32xf32> // CHECK: %[[cst:.*]] = constant dense<1.270000e+02> : tensor<32x12xf32> -// CHECK: %[[q:.*]] = "tfl.quantize"(%cst) {qtype = tensor<32x12x!quant.uniform:f32, 1.000000e+00>>} : (tensor<32x12xf32>) -> tensor<32x12x!quant.uniform:f32, 1.000000e+00>> +// CHECK: %[[q:.*]] = "tfl.quantize"(%cst) {qtype = tensor<32x12x!quant.uniform:f32, 1.000000e+00>>, volatile} // CHECK: %[[dq:.*]] = "tfl.dequantize"(%0) : (tensor<32x12x!quant.uniform:f32, 1.000000e+00>>) -> tensor<32x12xf32> // CHECK: "tfl.fully_connected"(%arg0, %[[dq]] // PerTensor: %[[cst:.*]] = constant dense<1.270000e+02> : tensor<32x12xf32> -// PerTensor: %[[q:.*]] = "tfl.quantize"(%cst) {qtype = tensor<32x12x!quant.uniform:f32, 1.000000e+00>>} : (tensor<32x12xf32>) -> tensor<32x12x!quant.uniform:f32, 1.000000e+00>> +// PerTensor: %[[q:.*]] = "tfl.quantize"(%cst) {qtype = tensor<32x12x!quant.uniform:f32, 1.000000e+00>>, volatile} // PerTensor: %[[dq:.*]] = "tfl.dequantize"(%0) : (tensor<32x12x!quant.uniform:f32, 1.000000e+00>>) -> tensor<32x12xf32> // PerTensor: "tfl.fully_connected"(%arg0, %[[dq]] } diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir index f937d0afd4d..38f76bb4eb5 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir @@ -67,7 +67,7 @@ func @QuantizeConv2DPerChannel(%arg0: tensor<1x224x224x3x!quant.uniform // CHECK-NEXT: %[[cst:.*]] = constant dense<1.000000e+00> : tensor<32xf32> -// CHECK-NEXT: %[[qbias:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<32x!quant.uniform>} +// CHECK-NEXT: %[[qbias:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<32x!quant.uniform>, volatile} // CHECK-NEXT: %[[bias:.*]] = "tfl.dequantize"(%[[qbias]]) // CHECK-NEXT: %[[in:.*]] = "tfl.dequantize"(%arg0) // CHECK-NEXT: %[[w:.*]] = "tfl.dequantize"(%arg1) @@ -87,7 +87,7 @@ func @QuantizeConv2DPerChannels(%arg0: tensor<1x224x224x3x!quant.uniform // CHECK-NEXT: %[[cst:.*]] = constant dense<1.000000e+00> : tensor<32xf32> -// CHECK-NEXT: %[[qbias:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<32x!quant.uniform>} +// CHECK-NEXT: %[[qbias:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<32x!quant.uniform>, volatile} // CHECK-NEXT: %[[bias:.*]] = "tfl.dequantize"(%[[qbias]]) // CHECK-NEXT: %[[in:.*]] = "tfl.dequantize"(%arg0) // CHECK-NEXT: %[[w:.*]] = "tfl.dequantize"(%arg1) @@ -107,7 +107,7 @@ func @QuantizeConv2D(tensor<1x224x224x3x!quant.uniform return %6 : tensor<1x112x112x32x!quant.uniform> // CHECK: %cst = constant dense<-1.23697901> : tensor<32xf32> -// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<32x!quant.uniform>} +// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<32x!quant.uniform>, volatile} // CHECK: %1 = "tfl.dequantize"(%0) : 
(tensor<32x!quant.uniform>) // CHECK: %2 = "tfl.dequantize"(%arg0) // CHECK: %3 = "tfl.pseudo_qconst"() @@ -129,7 +129,7 @@ func @QuantizeFullyConnected(tensor<1x224x224x3x!quant.uniform> // CHECK: %cst = constant dense<-1.23697901> : tensor<32xf32> -// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<32x!quant.uniform>} +// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<32x!quant.uniform>, volatile} // CHECK: %1 = "tfl.dequantize"(%0) : (tensor<32x!quant.uniform>) // CHECK: %2 = "tfl.dequantize"(%arg0) // CHECK: %3 = "tfl.pseudo_qconst"() @@ -151,7 +151,7 @@ func @QuantizeDepthwiseConv2D(tensor<1x224x224x3x!quant.uniform> // CHECK: %cst = constant dense<-1.23697901> : tensor<32xf32> -// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<32x!quant.uniform>} +// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<32x!quant.uniform>, volatile} // CHECK: %1 = "tfl.dequantize"(%0) : (tensor<32x!quant.uniform>) // CHECK: %2 = "tfl.dequantize"(%arg0) // CHECK: %3 = "tfl.pseudo_qconst"() @@ -232,7 +232,7 @@ func @QuantizeStridedSlice(tensor<12x2x2x5x!quant.uniform>, tensor< // CHECK: %0 = "tfl.dequantize"(%arg0) // CHECK: %1 = "tfl.strided_slice"(%0, %arg1, %arg2, %arg3) -// CHECK: %2 = "tfl.quantize"(%1) {qtype = tensor<1x2x2x5x!quant.uniform>} +// CHECK: %2 = "tfl.quantize"(%1) {qtype = tensor<1x2x2x5x!quant.uniform>, volatile} // CHECK: %3 = "tfl.dequantize"(%2) // CHECK: return %3 : tensor<1x2x2x5xf32> } @@ -277,7 +277,7 @@ func @QuantizeReshape2D(tensor<1x6x6x16x!quant.uniform // CHECK: %0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform>) // CHECK: %1 = "tfl.reshape"(%0, %{{.*}}) : (tensor<1x6x6x16xf32>, tensor<3xi32>) -> tensor<1x36x16xf32> -// CHECK: %2 = "tfl.quantize"(%1) {qtype = tensor<1x36x16x!quant.uniform>} +// CHECK: %2 = "tfl.quantize"(%1) {qtype = tensor<1x36x16x!quant.uniform>, volatile} // CHECK: %3 = "tfl.dequantize"(%2) : (tensor<1x36x16x!quant.uniform>) // CHECK: return %3 : tensor<1x36x16xf32> } @@ -291,7 +291,7 @@ func @QuantizeSoftmax(tensor<1x6x6x16x!quant.uniform>) // CHECK: %0 = "tfl.dequantize"(%arg0) // CHECK: %1 = "tfl.softmax"(%0) {beta = 1.000000e+00 : f32} : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32> -// CHECK: %2 = "tfl.quantize"(%1) {qtype = tensor<1x6x6x16x!quant.uniform>} +// CHECK: %2 = "tfl.quantize"(%1) {qtype = tensor<1x6x6x16x!quant.uniform>, volatile} // CHECK: %3 = "tfl.dequantize"(%2) // CHECK: return %3 : tensor<1x6x6x16xf32> } @@ -305,7 +305,7 @@ func @QuantizeLogistic(tensor<1x6x6x16x!quant.uniform> // CHECK: %0 = "tfl.dequantize"(%arg0) // CHECK: %1 = "tfl.logistic"(%0) : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32> -// CHECK: %2 = "tfl.quantize"(%1) {qtype = tensor<1x6x6x16x!quant.uniform>} +// CHECK: %2 = "tfl.quantize"(%1) {qtype = tensor<1x6x6x16x!quant.uniform>, volatile} // CHECK: %3 = "tfl.dequantize"(%2) : (tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> // CHECK: return %3 : tensor<1x6x6x16xf32> } @@ -327,7 +327,7 @@ func @QuantizeL2Norm(%arg0: tensor<1x6x6x16x!quant.uniform>) -> ten // CHECK: %[[in:.*]] = "tfl.dequantize"(%arg0) // CHECK: %[[l2:.*]] = "tfl.l2_normalization"(%[[in]]) -// CHECK: %[[q:.*]] = "tfl.quantize"(%[[l2]]) {qtype = tensor<1x6x6x16x!quant.uniform>} +// CHECK: %[[q:.*]] = "tfl.quantize"(%[[l2]]) {qtype = tensor<1x6x6x16x!quant.uniform>, volatile} // CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) // CHECK: return %[[dq]] : tensor<1x6x6x16xf32> } @@ -350,13 +350,13 @@ func @QuantizeConcatOperand0ToAll(tensor<1x2x!quant.uniform>, t %1 = "tfl.concatenation"(%0, %arg1) {axis = 0 : i32, 
fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32> return %1 : tensor<2x2xf32> -// CHECK: %0 = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK: %0 = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform>, volatile} // CHECK: %1 = "tfl.dequantize"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> // CHECK: %2 = "tfl.dequantize"(%arg0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> // CHECK: %3 = "tfl.concatenation"(%2, %1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32> -// CHECK: %4 = "tfl.quantize"(%3) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> +// CHECK: %4 = "tfl.quantize"(%3) {qtype = tensor<2x2x!quant.uniform>, volatile} // CHECK: %5 = "tfl.dequantize"(%4) : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> -// CHeCK: return %5 : tensor<2x2xf32> +// CHECK: return %5 : tensor<2x2xf32> } // CHECK-LABEL: QuantizeConcatOperand1ToAll @@ -366,11 +366,11 @@ func @QuantizeConcatOperand1ToAll(tensor<1x2xf32>, tensor<1x2x!quant.uniform, tensor<1x2xf32>) -> tensor<2x2xf32> return %1 : tensor<2x2xf32> -// CHECK: %0 = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK: %0 = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform>, volatile} // CHECK: %1 = "tfl.dequantize"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> // CHECK: %2 = "tfl.dequantize"(%arg1) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> // CHECK: %3 = "tfl.concatenation"(%1, %2) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32> -// CHECK: %4 = "tfl.quantize"(%3) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> +// CHECK: %4 = "tfl.quantize"(%3) {qtype = tensor<2x2x!quant.uniform>, volatile} // CHECK: %5 = "tfl.dequantize"(%4) : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> // CHECK: return %5 : tensor<2x2xf32> } @@ -382,9 +382,9 @@ func @QuantizeConcatResToAll(tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2x!qu %1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> return %1 : tensor<2x2x!quant.uniform> -// CHECK: %0 = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK: %0 = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform>, volatile} // CHECK: %1 = "tfl.dequantize"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> -// CHECK: %2 = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK: %2 = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform>, volatile} // CHECK: %3 = "tfl.dequantize"(%2) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> // CHECK: %4 = "tfl.concatenation"(%3, %1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32> // CHECK: %5 = "tfl.quantize"(%4) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> @@ -399,7 +399,7 @@ func @QuantizeConcatResToAllNoRequantize(tensor<1x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> return %2 : tensor<2x2x!quant.uniform> -// CHECK: %0 = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK: %0 = 
"tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform>, volatile} // CHECK: %1 = "tfl.dequantize"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> // CHECK: %2 = "tfl.dequantize"(%arg0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> // CHECK: %3 = "tfl.concatenation"(%2, %1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32> @@ -416,7 +416,7 @@ func @QuantizeConcatResToAllRequantize(tensor<1x2xf32>, tensor<1x2xf32>) -> tens %3 = "tfl.quantize"(%2) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> return %3 : tensor<2x2x!quant.uniform> -// CHECK: %[[Q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[Q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform>, volatile} // CHECK: %[[DQ1:.*]] = "tfl.dequantize"(%[[Q1]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> // CHECK: %[[Q0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> // CHECK: %[[RQ0:.*]] = "tfl.quantize"(%[[Q0]]) {qtype = tensor<1x2x!quant.uniform>} : (tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> @@ -434,7 +434,7 @@ func @QuantizeConcatResToAllRequantizeArg(tensor<1x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> return %3 : tensor<2x2x!quant.uniform> -// CHECK: %[[Q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[Q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform>, volatile} // CHECK: %[[DQ1:.*]] = "tfl.dequantize"(%[[Q1]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> // CHECK: %[[RQ0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform>} : (tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> // CHECK: %[[DQ0:.*]] = "tfl.dequantize"(%[[RQ0]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> @@ -475,22 +475,22 @@ func @QuantizeChain(tensor<1x224x224x3x!quant.uniform> return %10 : tensor<1x36x16xf32> // CHECK: %cst = constant dense<-1.23697901> : tensor<32xf32> -// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<32x!quant.uniform>} +// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<32x!quant.uniform>, volatile} // CHECK: %1 = "tfl.dequantize"(%0) : (tensor<32x!quant.uniform>) // CHECK: %2 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform>) // CHECK: %3 = "tfl.pseudo_qconst"() // CHECK: %4 = "tfl.dequantize"(%3) : (tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>) // CHECK: %5 = "tfl.average_pool_2d"(%2) -// CHECK: %6 = "tfl.quantize"(%5) {qtype = tensor<1x224x224x3x!quant.uniform>} +// CHECK: %6 = "tfl.quantize"(%5) {qtype = tensor<1x224x224x3x!quant.uniform>, volatile} // CHECK: %7 = "tfl.dequantize"(%6) : (tensor<1x224x224x3x!quant.uniform>) // CHECK: %8 = "tfl.conv_2d"(%7, %4, %1) // CHECK: %9 = "tfl.quantize"(%8) {qtype = tensor<1x112x112x32x!quant.uniform>} // CHECK: %10 = "tfl.dequantize"(%9) : (tensor<1x112x112x32x!quant.uniform>) // CHECK: %11 = "tfl.reshape"(%10, %{{.*}}) -// CHECK: %12 = "tfl.quantize"(%11) {qtype = tensor<1x36x16x!quant.uniform>} +// CHECK: %12 = "tfl.quantize"(%11) {qtype = tensor<1x36x16x!quant.uniform>, volatile} // CHECK: %13 = "tfl.dequantize"(%12) : (tensor<1x36x16x!quant.uniform>) // CHECK: %14 = "tfl.softmax"(%13) -// CHECK: %15 = "tfl.quantize"(%14) {qtype = tensor<1x36x16x!quant.uniform>} +// CHECK: %15 = "tfl.quantize"(%14) {qtype = 
tensor<1x36x16x!quant.uniform>, volatile} // CHECK: %16 = "tfl.dequantize"(%15) : (tensor<1x36x16x!quant.uniform>) // CHECK: return %16 : tensor<1x36x16xf32> } @@ -501,7 +501,7 @@ func @QuantizeConstant() -> tensor<2x3xf32> { return %cst : tensor<2x3xf32> // CHECK: %cst = constant dense{{.*}}tensor<2x3xf32> -// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<2x3x!quant.uniform>} +// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<2x3x!quant.uniform>, volatile} // CHECK: %1 = "tfl.dequantize"(%0) // CHECK: return %1 : tensor<2x3xf32> } @@ -521,7 +521,7 @@ func @QuantizeZeroSplat() -> tensor<2x3xf32> { return %cst : tensor<2x3xf32> // CHECK-NEXT: %[[cst:.*]] = constant dense<0.000000e+00> : tensor<2x3xf32> -// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor<2x3x!quant.uniform>} +// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor<2x3x!quant.uniform>, volatile} } // CHECK-LABEL: QuantizeZeroScalar @@ -530,7 +530,7 @@ func @QuantizeZeroScalar() -> tensor { return %cst : tensor // CHECK-NEXT: %[[cst:.*]] = constant dense<0.000000e+00> : tensor -// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor>} +// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor>, volatile} } // CHECK-LABEL: QuantizePositiveSplat @@ -539,7 +539,7 @@ func @QuantizePositiveSplat() -> tensor<2x3xf32> { return %cst : tensor<2x3xf32> // CHECK-NEXT: %[[cst:.*]] = constant dense<2.540000e+01> : tensor<2x3xf32> -// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor<2x3x!quant.uniform>} +// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor<2x3x!quant.uniform>, volatile} } // CHECK-LABEL: QuantizePositiveScalar @@ -548,7 +548,7 @@ func @QuantizePositiveScalar() -> tensor { return %cst : tensor // CHECK-NEXT: %[[cst:.*]] = constant dense<2.540000e+00> : tensor -// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor>} +// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor>, volatile} } // CHECK-LABEL: QuantizeNegativeSplat @@ -557,7 +557,7 @@ func @QuantizeNegativeSplat() -> tensor<2x3xf32> { return %cst : tensor<2x3xf32> // CHECK-NEXT: %[[cst:.*]] = constant dense<-2.540000e+00> : tensor<2x3xf32> -// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor<2x3x!quant.uniform>} +// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor<2x3x!quant.uniform>, volatile} } // CHECK-LABEL: QuantizeNegativeScalar @@ -566,7 +566,7 @@ func @QuantizeNegativeScalar() -> tensor { return %cst : tensor // CHECK-NEXT: %[[cst:.*]] = constant dense<-2.540000e+01> : tensor -// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor>} +// CHECK-NEXT: "tfl.quantize"(%[[cst]]) {qtype = tensor>, volatile} } // CHECK-LABEL: QuantizeSharedBiases @@ -617,7 +617,7 @@ func @QuantizeSharedBiases2( // CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]]) // CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) // CHECK: %[[cst_0:.*]] = constant dense<0.000000e+00> : tensor<32xf32> -// CHECK: %[[q_0:.*]] = "tfl.quantize"(%[[cst_0]]) {qtype = tensor<32x!quant.uniform>} +// CHECK: %[[q_0:.*]] = "tfl.quantize"(%[[cst_0]]) {qtype = tensor<32x!quant.uniform>, volatile} // CHECK: %[[dq_0:.*]] = "tfl.dequantize"(%[[q_0]]) // CHECK: %{{.*}} = tfl.add %{{.*}}, %[[dq_0]] // CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]]) diff --git a/tensorflow/compiler/mlir/lite/tests/quantize.mlir b/tensorflow/compiler/mlir/lite/tests/quantize.mlir index 54ca7f043f4..6f42ae6293d 100644 --- a/tensorflow/compiler/mlir/lite/tests/quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/quantize.mlir @@ -195,8 +195,8 @@ func @QuantizeConcat(tensor<1x2xf32>, tensor<1x2xf32>) -> 
tensor<2x2x!quant.unif %1 = "tfl.quantize"(%0) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> return %1 : tensor<2x2x!quant.uniform> -// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform>} -// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform>} +// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform>, volatile} +// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform>, volatile} // CHECK: %[[cc:.*]] = "tfl.concatenation"(%[[q1]], %[[q0]]) {axis = 0 : i32, fused_activation_function = "NONE"} // CHECK: return %[[cc]] : tensor<2x2x!quant.uniform> } @@ -209,8 +209,8 @@ func @QuantizeConcatRequantize(tensor<1x2x!quant.uniform>, tens %3 = "tfl.quantize"(%2) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> return %3 : tensor<2x2x!quant.uniform> -// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform>} -// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform>} : (tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x2x!quant.uniform>, volatile} +// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform>} // CHECK: %[[cc:.*]] = "tfl.concatenation"(%[[q0]], %[[q1]]) {axis = 0 : i32, fused_activation_function = "NONE"} // CHECK: return %[[cc]] : tensor<2x2x!quant.uniform> } diff --git a/tensorflow/compiler/mlir/lite/tests/split-merged-operands.mlir b/tensorflow/compiler/mlir/lite/tests/split-merged-operands.mlir index d2d0e43e0e9..c5c9ee645f4 100644 --- a/tensorflow/compiler/mlir/lite/tests/split-merged-operands.mlir +++ b/tensorflow/compiler/mlir/lite/tests/split-merged-operands.mlir @@ -1,27 +1,27 @@ // RUN: tf-opt -tfl-split-merged-operands %s | FileCheck %s -func @testSingleLstm(%arg0: tensor<4 x f32>) -> tensor<4xf32> { +func @testSingleLstm(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>) -> tensor<4x4xf32> { // CHECK-LABEL: testSingleLstm - // CHECK: %[[CST_0:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32> - // CHECK: %[[CST_1:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32> - // CHECK: %[[LSTM:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK: %[[CST_0:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32> + // CHECK: %[[CST_1:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32> + // CHECK: %[[LSTM:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : 
(tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - %0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") - %1 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %1 : tensor<4xf32> + %0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> loc("Const") + %1 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + return %1 : tensor<4x4xf32> } -func @testMultipleLstms(%arg0: tensor<4 x f32>) -> tensor<4xf32> { +func @testMultipleLstms(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>) -> tensor<4x4xf32> { // CHECK-LABEL: testMultipleLstms - // CHECK: %[[CST_0:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32> - // CHECK: %[[CST_1:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32> - // CHECK: %[[LSTM_1:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - // CHECK: %[[CST_2:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32> - // CHECK: %[[CST_3:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32> - // CHECK: %[[LSTM_2:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%[[LSTM_1]], %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, 
%[[CST_2]], %[[CST_3]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK: %[[CST_0:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32> + // CHECK: %[[CST_1:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32> + // CHECK: %[[LSTM_1:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + // CHECK: %[[CST_2:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32> + // CHECK: %[[CST_3:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32> + // CHECK: %[[LSTM_2:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%[[LSTM_1]], %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %[[CST_2]], %[[CST_3]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - %0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") - %1 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - %2 = "tfl.unidirectional_sequence_lstm"(%1, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, 
tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %2 : tensor<4xf32> + %0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> loc("Const") + %1 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + %2 = "tfl.unidirectional_sequence_lstm"(%1, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + return %2 : tensor<4x4xf32> } diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index 6ab16141626..40420eee697 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -48,7 +48,8 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs, quant_specs.default_ranges.second.hasValue()) { pass_manager->addPass(mlir::TFL::CreateDefaultQuantParamsPass( quant_specs.default_ranges.first.getValueOr(0.0), - quant_specs.default_ranges.second.getValueOr(0.0))); + quant_specs.default_ranges.second.getValueOr(0.0), + quant_specs.IsSignedInferenceType())); pass_manager->addPass(mlir::TFL::CreateQuantizePass()); pass_manager->addPass( mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops)); @@ -161,6 +162,10 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, pass_manager->addPass( mlir::TFL::CreatePrepareTFPass(pass_config.unfold_batch_matmul)); pass_manager->addNestedPass(mlir::createCanonicalizerPass()); + if (pass_config.shape_inference) { + // Add a shape inference pass to optimize away the unnecessary casts. 
+ pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass()); + } pass_manager->addPass( mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification)); pass_manager->addPass(mlir::TFL::CreateOptimizePass()); diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index 4bc9d9e0c2d..31dad60c294 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -160,6 +160,11 @@ int main(int argc, char **argv) { absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty()); absl::Span exported_names(exported_names_vector); + if (exported_names.size() != 1) { + llvm::errs() << "There should be only one exported name"; + return kTrFailure; + } + module = tensorflow::ImportSavedModel(input_file_name, saved_model_version, tags, exported_names, &context); } else { @@ -175,7 +180,7 @@ int main(int argc, char **argv) { if (!module.ok()) return kTrFailure; mlir::PassManager pm(&context); - applyPassManagerCLOptions(pm); + mlir::applyPassManagerCLOptions(pm); // Set the quantization specifications from the command line flags. mlir::TFL::QuantizationSpecs quant_specs; diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index b9ec67736d9..62f64ab63b4 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -174,7 +174,7 @@ StatusOr ImportSavedModel( return module; } else if (saved_model_version == 1) { auto module = tensorflow::SavedModelSignatureDefsToMlirImport( - input_filename, tags, context); + input_filename, tags, exported_names, context); if (!module) return tensorflow::errors::InvalidArgument("fail to open input file"); diff --git a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc index a1602baced5..c23ae9fcfab 100644 --- a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc +++ b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc @@ -46,8 +46,11 @@ namespace { class DefaultQuantParamsPass : public PassWrapper { public: - explicit DefaultQuantParamsPass(double default_min, double default_max) - : default_min_(default_min), default_max_(default_max) {} + explicit DefaultQuantParamsPass(double default_min, double default_max, + bool is_signed) + : default_min_(default_min), + default_max_(default_max), + is_signed_(is_signed) {} void runOnFunction() override; @@ -82,6 +85,7 @@ class DefaultQuantParamsPass double default_min_; double default_max_; + bool is_signed_; quant::QuantParams default_quant_params_; }; } // namespace @@ -214,15 +218,16 @@ quant::QuantParams DefaultQuantParamsPass::GetDefaultQuantParams( default_quant_params_ = quant::fakeQuantAttrsToType( builder.getUnknownLoc(), /*numBits=*/8, default_min_, default_max_, /*narrowRange=*/false, - builder.getF32Type()); + builder.getF32Type(), is_signed_); } return default_quant_params_; } // Creates an instance of the default quant parameters pass. 
std::unique_ptr> CreateDefaultQuantParamsPass( - double default_min, double default_max) { - return absl::make_unique(default_min, default_max); + double default_min, double default_max, bool is_signed) { + return absl::make_unique(default_min, default_max, + is_signed); } // Registers this pass with default values, only for test @@ -230,7 +235,8 @@ static PassRegistration pass( "tfl-default-quant", "Apply quantization with default quantization parameter", [] { return CreateDefaultQuantParamsPass(/*default_min=*/-1.0, - /*default_max=*/1.0); + /*default_max=*/1.0, + /*is_signed=*/false); }); } // namespace TFL diff --git a/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc b/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc index 201a0bb2481..9b526f40277 100644 --- a/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc +++ b/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc @@ -321,7 +321,8 @@ void DenseToSparse::runOnFunction() { if (result.needs_densify) { const auto value = op->getOperand(operand); - auto densify = builder.create(op->getLoc(), value); + auto densify = + builder.create(op->getLoc(), value.getType(), value); value.replaceAllUsesWith(densify); densify.setOperand(value); } diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index 4c6a16c2233..3e9d6e488b8 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -33,9 +33,6 @@ class ExtractI32At : NativeCodeCall< "$_builder.getI32IntegerAttr($_self.cast().getValue()[" # i # "].cast().getInt())">; -// Merge the two Attributes to a ArrayAttr; -def Merge2AttrsToArray : NativeCodeCall<"$_builder.getArrayAttr({$0, $1})">; - // Use the tensor type information from $0 and convert min $1, max $2 and // numBits $3 and narrowRange $4 to a QuantizedType. def ConvertToQuantTypeFromAttrs : NativeCodeCall< @@ -61,15 +58,12 @@ def HasNotSameStaticShapes : Constraint, "op must h def CreateNoneValue : NativeCodeCall< "$_builder.create($0.getLoc(), $_builder.getNoneType(), $_builder.getUnitAttr())">; -// Checks if the value has only one user. -// TODO(karimnosseir): Move to a common place? -def HasOneUse : Constraint>; - //===----------------------------------------------------------------------===// // Nullary ops patterns. //===----------------------------------------------------------------------===// -def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>; +def LegalizeTFConstToTFLConst: Pat<(TF_ConstOp ElementsAttr:$value), + (TFL_ConstOp $value)>; // Convert to std constant for statically shaped, non-opaque constants. def : Pat<(TF_ConstOp:$res NonOpaqueElementsAttr:$value), (ConstantOp $value), @@ -79,17 +73,23 @@ def : Pat<(TF_ConstOp:$res NonOpaqueElementsAttr:$value), (ConstantOp $value), // Unary ops patterns. //===----------------------------------------------------------------------===// def IsDataFormatNHWC : ConstantAttr; + +// Constraint that Attr has values [1, X, Y, 1] def IsIntList1XY1 : AttrConstraint>; + +// Constraint that values in list attribute are all ones. 
def IsAllOnes : AttrConstraint>; + +// Constraint that attribute is string with value either "SAME" or "VALID" def IsSameOrValid : AttrConstraint< CPred<"$_self.cast().getValue() == \"SAME\" || " # "$_self.cast().getValue() == \"VALID\"">, "'SAME' or 'VALID' paddings">; -def : Pat<(TF_AbsOp $arg), (TFL_AbsOp $arg)>; -def : Pat<(TF_AddNOp $inputs), (TFL_AddNOp $inputs)>; +def LegalizeAbs : Pat<(TF_AbsOp $arg), (TFL_AbsOp $arg)>; +def LegalizeAddN : Pat<(TF_AddNOp $inputs), (TFL_AddNOp $inputs)>; -def : Pat<(TF_AvgPoolOp $value, +def LegalizeAveragePool : Pat<(TF_AvgPoolOp $value, IsIntList1XY1:$ksize, IsIntList1XY1:$strides, $padding, @@ -102,35 +102,42 @@ def : Pat<(TF_AvgPoolOp $value, /*stride_w=*/ExtractI32At<2>:$strides, /*fused_activation_function=*/TFL_AF_None)>; -def : Pat<(TF_ArgMaxOp $input, $dim), (TFL_ArgMaxOp $input, $dim)>; -def : Pat<(TF_ArgMinOp $input, $dim), (TFL_ArgMinOp $input, $dim)>; +def LegalizeArgMax : Pat<(TF_ArgMaxOp $input, $dim), + (TFL_ArgMaxOp $input, $dim)>; +def LegalizeArgMin : Pat<(TF_ArgMinOp $input, $dim), + (TFL_ArgMinOp $input, $dim)>; -def : Pat<(TF_CeilOp $arg), (TFL_CeilOp $arg)>; +def LegalizeCeil : Pat<(TF_CeilOp $arg), (TFL_CeilOp $arg)>; -def : Pat<(TF_CosOp $arg), (TFL_CosOp $arg)>; +def LegalizeCos : Pat<(TF_CosOp $arg), (TFL_CosOp $arg)>; -def : Pat<(TF_EluOp $arg), (TFL_EluOp $arg)>; +def LegalizeElu : Pat<(TF_EluOp $arg), (TFL_EluOp $arg)>; -def : Pat<(TF_ExpandDimsOp $input, $dim), (TFL_ExpandDimsOp $input, $dim)>; +def LegalizeExpandDims : Pat<(TF_ExpandDimsOp $input, $dim), + (TFL_ExpandDimsOp $input, $dim)>; -def : Pat<(TF_FakeQuantWithMinMaxArgsOp $inputs, - $min, $max, - $num_bits, $narrow_range), - (TFL_DequantizeOp - (TFL_QuantizeOp $inputs, - (ConvertToQuantTypeFromAttrs $inputs, $min, $max, - $num_bits, $narrow_range)))>; +def LegalizeFakeQuantToDequantizeQuantize : Pat< + (TF_FakeQuantWithMinMaxArgsOp $inputs, $min, $max, $num_bits, $narrow_range), + (TFL_DequantizeOp + (TFL_QuantizeOp $inputs, + (ConvertToQuantTypeFromAttrs $inputs, $min, $max, + $num_bits, $narrow_range)))>; -def : Pat<(TF_FillOp $arg, $value), (TFL_FillOp $arg, $value)>; +def LegalizeFill : Pat<(TF_FillOp $arg, $value), (TFL_FillOp $arg, $value)>; -def : Pat<(TF_FloorOp $arg), (TFL_FloorOp $arg)>; +def LegalizeFloor : Pat<(TF_FloorOp $arg), (TFL_FloorOp $arg)>; -def : Pat<(TF_LeakyReluOp $arg, F32Attr:$a), (TFL_LeakyReluOp $arg, $a)>; -def : Pat<(TF_LogOp $arg), (TFL_LogOp $arg)>; -def : Pat<(TF_LogicalNotOp $arg), (TFL_LogicalNotOp $arg)>; -def : Pat<(TF_LogSoftmaxOp $arg), (TFL_LogSoftmaxOp $arg)>; +def LegalizeLeakyRelu : Pat<(TF_LeakyReluOp $arg, F32Attr:$a), + (TFL_LeakyReluOp $arg, $a)>; -def : Pat<(TF_MaxPoolOp $value, +def LegalizeLog : Pat<(TF_LogOp $arg), (TFL_LogOp $arg)>; + +def LegalizeNot : Pat<(TF_LogicalNotOp $arg), (TFL_LogicalNotOp $arg)>; + +def LegalizeLogSoftmax : Pat<(TF_LogSoftmaxOp $arg), (TFL_LogSoftmaxOp $arg)>; + +def LegalizeMaxPool2D : Pat< + (TF_MaxPoolOp $value, IsIntList1XY1:$ksize, IsIntList1XY1:$strides, $padding, @@ -143,8 +150,12 @@ def : Pat<(TF_MaxPoolOp $value, /*filter_height=*/ExtractI32At<1>:$ksize, /*fused_activation_function=*/TFL_AF_None)>; -def : Pat<(TF_MaximumOp $arg1, $arg2), (TFL_MaximumOp $arg1, $arg2)>; -def : Pat<(TF_MinimumOp $arg1, $arg2), (TFL_MinimumOp $arg1, $arg2)>; +def LegalizeMaximum : Pat<(TF_MaximumOp $arg1, $arg2), + (TFL_MaximumOp $arg1, $arg2)>; + +def LegalizeMinimum : Pat<(TF_MinimumOp $arg1, $arg2), + (TFL_MinimumOp $arg1, $arg2)>; + def : Pat<(TF_NegOp $arg), (TFL_NegOp $arg)>; def : 
Pat<(TF_OneHotOp $indices, $depth, $on_value, $off_value, $axis), (TFL_OneHotOp $indices, $depth, $on_value, $off_value, diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index ab4c4f5c4cf..bfcbc190638 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -37,6 +37,7 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" @@ -202,6 +203,26 @@ LogicalResult ConvertTFConcatOp::matchAndRewrite( return success(); } +// Converts any IntegerAttr to an IntegerAttr of an i32 type. +// The value won't change in the new attribute, but if the value is out of +// the bound of i32, the function returns a failure. +LogicalResult ConvertToI32Attr(IntegerAttr attr, IntegerAttr* attr_i32) { + if (attr.getType().isInteger(/*width=*/32)) { + *attr_i32 = attr; + return success(); + } + + int64_t value = attr.getInt(); + if (value > std::numeric_limits::max() || + value < std::numeric_limits::min()) { + return failure(); + } + + *attr_i32 = IntegerAttr::get( + IntegerType::get(/*width=*/32, attr.getContext()), value); + return success(); +} + LogicalResult ConvertTFConcatV2Op::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_concat_op = cast(op); @@ -211,12 +232,16 @@ LogicalResult ConvertTFConcatV2Op::matchAndRewrite( // Extract axis attribute from constant axis tensor ElementsAttr axis; if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis))) return failure(); + IntegerAttr axis_int = ExtractSingleElementAsInteger(axis); + + // "axis" operand could be a i64 tensor. Resolve it here. 
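// Editorial illustration (not part of this change; the values and the use of `rewriter` are hypothetical):
// ConvertToI32Attr, defined above, keeps the integer value but rebuilds the attribute with an i32 type,
// and fails when the value does not fit, e.g.
//
//   IntegerAttr axis_i64 = rewriter.getI64IntegerAttr(3);
//   IntegerAttr narrowed;
//   succeeded(ConvertToI32Attr(axis_i64, &narrowed));   // true: narrowed is an i32 attr with value 3
//   succeeded(ConvertToI32Attr(rewriter.getI64IntegerAttr(int64_t{1} << 40),
//                              &narrowed));             // false: value is outside the i32 range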
+ IntegerAttr axis_i32; + if (failed(ConvertToI32Attr(axis_int, &axis_i32))) return failure(); StringAttr fused_activation_function = StringAttr::get("NONE", rewriter.getContext()); rewriter.replaceOpWithNewOp( - op, output_type, values, ExtractSingleElementAsInteger(axis), - fused_activation_function); + op, output_type, values, axis_i32, fused_activation_function); return success(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc b/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc index 307a45639c5..5cb659fa318 100644 --- a/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc +++ b/tensorflow/compiler/mlir/lite/transforms/load_quantization_recipe.cc @@ -184,7 +184,7 @@ void LoadQuantizationRecipe::LoadForLSTMOp(LSTMOp lstm, OpBuilder* builder) { auto new_cell_tanh = builder->create(loc, int16, new_cell); auto hidden_state = builder->create( - loc, int16, new_cell_tanh.y(), output_gate->getResult(0), none_af); + loc, int16, new_cell_tanh.output(), output_gate->getResult(0), none_af); auto act = builder->create( loc, int8, hidden_state.output(), lstm.projection_weights(), lstm.projection_bias(), none_af, fc_format, keep_dims); diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index a69b0a3c624..2498a732a86 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -577,7 +577,6 @@ struct ConvertTensorListResize ArrayRef({input_handle, input_shape, size_diff, size}), /*then_branch=*/rewriter.getSymbolRefAttr(then_branch_op), /*else_branch=*/rewriter.getSymbolRefAttr(else_branch_op), - /*output_shapes=*/rewriter.getArrayAttr({}), /*is_stateless=*/rewriter.getBoolAttr(true)); return success(); } @@ -838,7 +837,8 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction( // TensorFlow operations that doesn't have operands and results of type // variant are legal. Here, we don't distinguish between variants encoding // TensorList or some other type as that information is not available here. - // This constraint should be relaxed to support other variant types in TFLite. + // Partial legalization is used below to still allow ops with variant + // types. auto is_legal = [](Operation *op) { auto is_not_variant = [](Type ty) { return !ty.cast().getElementType().isa(); }; @@ -859,6 +859,7 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction( target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); + target.addLegalOp(); // Register fused LSTM/RNN ops as legal.
target.addLegalOp(); target.addLegalOp(); @@ -872,7 +873,7 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction( ConvertTensorListPushBack, ConvertTensorListReserve, ConvertTensorListSetItem, ConvertTensorListStack, ConvertTensorListResize, ConvertWhile>(context); - return applyFullConversion(func, target, patterns); + return applyPartialConversion(func, target, patterns); } void LowerStaticTensorListPass::runOnOperation() { diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index a1aedb0af32..30ae4b81f4f 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -206,6 +206,28 @@ DenseElementsAttr GetShape(Value output_val) { llvm::makeArrayRef(shape)); } +static Type GetShapeStrippedType(TypeAttr type_attr) { + auto type = type_attr.getValue(); + auto shaped_type = type.dyn_cast(); + if (shaped_type) { + return shaped_type.getElementType(); + } else { + return type; + } +} + +bool NotFromQuantOpDifferentQuant(Value val, TypeAttr qtype_attr) { + auto val_defn_op = val.getDefiningOp(); + TFL::QuantizeOp q_op = llvm::dyn_cast_or_null(val_defn_op); + if (!q_op) return true; + + // Ignore shape details - we're really only trying to + // check if the quantization is the same. + auto stripped_src_qtype = GetShapeStrippedType(q_op.qtypeAttr()); + auto stripped_qtype = GetShapeStrippedType(qtype_attr); + return stripped_src_qtype == stripped_qtype; +} + #include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc" // Fuse Add with proceeding FullyConnected. @@ -641,8 +663,8 @@ struct ConvertTrivialTransposeOpToReshapeOp LogicalResult matchAndRewrite(TFL::TransposeOp transpose_op, PatternRewriter &rewriter) const override { - auto input_type = transpose_op.x().getType().cast(); - auto output_type = transpose_op.y().getType().cast(); + auto input_type = transpose_op.input().getType().cast(); + auto output_type = transpose_op.output().getType().cast(); // It's possible to know if the transformation is safe only if the input // & output shapes are fully known and permutation is a constant. if (!input_type.hasStaticShape() || !output_type.hasStaticShape()) @@ -691,7 +713,8 @@ struct ConvertTrivialTransposeOpToReshapeOp auto new_shape = rewriter.create(loc, new_shape_attr); rewriter.replaceOpWithNewOp( - transpose_op, transpose_op.y().getType(), transpose_op.x(), new_shape); + transpose_op, transpose_op.output().getType(), transpose_op.input(), + new_shape); return success(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index a3244f31053..e8f1c9c2cf3 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -29,6 +29,14 @@ def ExtractSingleElementAsFloat : NativeCodeCall< // Checks if the value has only one user. def HasOneUse : Constraint>; +// Checks if the value has rank at most 'n'. +class HasRankAtMost : Constraint< + CPred<"$0.getType().cast().getRank() <= " # n>>; + +// Checks that the value is not produced by a TFL_Quant op with a +// different quantization attribute. +def NotFromQuantOpDifferentQuant : Constraint>; + //===----------------------------------------------------------------------===// // Ternary ops patterns.
//===----------------------------------------------------------------------===// @@ -160,7 +168,10 @@ foreach BinaryOp = [TFL_DivOp, TFL_MulOp] in // This pattern applies when the same quantize/dequantize have been used twice // with the same scale. We want to remove the redundancy. // TODO(fengliuai): move this to the sanity check of pre-quantize pass. -def : Pat<(TFL_QuantizeOp (TFL_DequantizeOp $in), $qt), (replaceWithValue $in)>; +def eliminate_dq_q_pairs : Pat< + (TFL_QuantizeOp (TFL_DequantizeOp $in), $qt), + (replaceWithValue $in), + [(NotFromQuantOpDifferentQuant $in, $qt)]>; // Constraint that makes sure both operands are the same operands. @@ -347,7 +358,9 @@ foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in { // The result of the new "BinaryOp" will have the same shape as // `input`. In other words, the shape of the `Reshape` op are not // changed after the transformation. - (IsTailOfShape $rhs, $input)]>; + (IsTailOfShape $rhs, $input), + (HasRankAtMost<5> $input), + (HasRankAtMost<5> $rhs)]>; } foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp, diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h index 959c17e317a..105c9394fb4 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/transforms/passes.h @@ -76,7 +76,7 @@ std::unique_ptr> CreateOptimizeFunctionalOpsPass(); // Creates an instance of the TensorFlow Lite dialect pass to add default // quantization parameters. std::unique_ptr> CreateDefaultQuantParamsPass( - double default_min, double default_max); + double default_min, double default_max, bool is_signed); // Creates an instance of the TensorFlow Lite dialect pass to convert dense // tensor to sparse format. diff --git a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc index 97b7d57dbf4..9a1da0ad03d 100644 --- a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc @@ -18,6 +18,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" @@ -118,6 +119,24 @@ void RemoveQuantizationAdaptorOps(FuncOp func) { func.setType(new_func_type); } +// Remove the back-to-back quantize and dequantize ops with volatile attribute. 
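// Editorial sketch (not part of this change; the IR names below are hypothetical): the
// RemoveVolatileOps pattern defined next matches a tfl.dequantize whose producer is a
// tfl.quantize carrying the volatile attribute and forwards the original float value to
// the dequantize's users, e.g.
//
//   %q = "tfl.quantize"(%x) {qtype = ..., volatile}
//   %y = "tfl.dequantize"(%q)
//   ... users of %y ...
//
// becomes, once the users are rewired and the dead pair is cleaned up,
//
//   ... users of %x ...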
+struct RemoveVolatileOps : public OpRewritePattern { + explicit RemoveVolatileOps(MLIRContext* context) + : OpRewritePattern(context, 1) {} + + LogicalResult matchAndRewrite(DequantizeOp op, + PatternRewriter& rewriter) const override { + auto input_op = op.input().getDefiningOp(); + if (auto q = llvm::dyn_cast_or_null(input_op)) { + if (!q.getAttr(mlir::quant::kVolatileOpAttrName)) return failure(); + + op.replaceAllUsesWith(q.input()); + return success(); + } + return failure(); + } +}; + #include "tensorflow/compiler/mlir/lite/transforms/generated_post_quantize.inc" void PostQuantizePass::runOnFunction() { @@ -125,11 +144,15 @@ void PostQuantizePass::runOnFunction() { auto func = getFunction(); auto* ctx = func.getContext(); TFL::populateWithGenerated(ctx, &patterns); + patterns.insert>(ctx); applyPatternsAndFoldGreedily(func, patterns); if (!emit_quant_adaptor_ops_) { RemoveQuantizationAdaptorOps(getFunction()); } + + patterns.insert(ctx); + applyPatternsAndFoldGreedily(func, patterns); } } // namespace diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc index 6179eb2ce64..56af68f6bbe 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc @@ -41,15 +41,22 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h" +#include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +// The cmd line flag to turn on/off Tf.Text API fusion. // NOLINTNEXTLINE +static llvm::cl::opt fuse_tftext( + "tfl-fuse-tftext", llvm::cl::value_desc("bool"), + llvm::cl::desc("Fuse TF.Text API ops when it's true"), + llvm::cl::init(false)); namespace mlir { namespace TFL { namespace { constexpr char kTFAPIImplements[] = "tf.api_implements"; +constexpr char kTfTextAPIPRefix[] = "tftext:"; // Abstracts the conversion of the embedded lookup composite function. class ConvertEmbeddedLookupFunc { @@ -187,6 +194,10 @@ void PrepareCompositeFunctionsPass::ConvertTFAPIImplements(FuncOp func, OpBuilder builder(func.getBody()); if (failed(ConvertKerasLSTMLayer(func, &builder))) return signalPassFailure(); + } else if (fuse_tftext && attr.getValue().startswith(kTfTextAPIPRefix)) { + if (failed(ConvertTFTextAPI(func, attr.getValue()))) { + return signalPassFailure(); + } } } diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc index a9e10a485bf..579063f9c9d 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc @@ -70,6 +70,7 @@ class PrepareQuantizePass : public PassWrapper { public: // Constructor used by the PassRegistration and enforce uint8 quantization. + // This is only used by test. explicit PrepareQuantizePass() { if (quantize_signed) quant_specs_.inference_type = tensorflow::DT_QINT8; @@ -108,8 +109,8 @@ class PrepareQuantizePass // Get the min and max values from the quantization specification for the // current function function and argument index. Uses default values if // the function is specified in the `quantize_whitelist`. 
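// Editorial note (not part of this change): with this change GetMinMaxValuesForArgument
// returns a pair of llvm::Optional values, so "no range specified" is encoded as an
// unset Optional rather than a default number, e.g.
//
//   auto min_max = GetMinMaxValuesForArgument(func_name, i);
//   if (!min_max.first.hasValue() || !min_max.second.hasValue())
//     return;  // skip arguments without user-provided ranges
//
// which is exactly how the caller below uses it.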
- std::pair GetMinMaxValuesForArgument( - llvm::StringRef func_name, int index) { + std::pair, llvm::Optional> + GetMinMaxValuesForArgument(llvm::StringRef func_name, int index) { if (func_name == quant_specs_.target_func) { return quant_specs_.input_ranges[index]; } else { @@ -159,10 +160,14 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) { } auto min_max = GetMinMaxValuesForArgument(func_name, i); + // If the input min/max or mean/std are not specified, then skip. + if (!min_max.first.hasValue() || !min_max.second.hasValue()) return; + TypeAttr params = quant::GetQuantizedTypeAttr( - builder, input_type, builder.getF64FloatAttr(min_max.first), - builder.getF64FloatAttr(min_max.second), /*quant_dim=*/-1, num_bits, - narrow_range, is_signed); + builder, input_type, + builder.getF64FloatAttr(min_max.first.getValue()), + builder.getF64FloatAttr(min_max.second.getValue()), + /*quant_dim=*/-1, num_bits, narrow_range, is_signed); builder.setInsertionPoint(block, insertion_point); auto q_op = builder.create(loc, params.getValue(), arg); @@ -205,6 +210,7 @@ void PrepareQuantizePass::SanityCheckAndAdjustment(FuncOp func) { // one is returned directly, we decide to return the quantized result instead, // so this op can be quantized. This is only applied on the returned result // because the error will not be accumulated. + func.walk([&](ReturnOp ret) { int i = 0; for (Value returned : ret.operands()) { @@ -232,6 +238,51 @@ void PrepareQuantizePass::SanityCheckAndAdjustment(FuncOp func) { "Missing quantization parameter on the output might introduce " "quantization error!"); }); + + // Check for (Quant (Dequant $in), $qA) "qdq" pairs that couldn't be + // eliminated at this point. This only occurs for the pattern + // (Quant (Dequant (Quant $in, $qB)), $qA) $qB != $qA + // where the qdq pair denotes a non-trivial requantization of an + // already quantized value. Since this makes little sense (directly quantizing + // (Quant $in, $qA) would introduce less quantization noise) the likely cause + // is a minor error in constructing the original network model that + // introduced back-to-back Fake Quantization operations. Hence: emit a + // warning. N.b. at this point we're (temporarily) in the quantization dialect + // (presumably to enable re-use in xla etc.), so it is quant::*QuantizeCastOp + // we're matching here. + // + func.walk([&](quant::QuantizeCastOp q_op) { + // If we end up with a quantize op fed by a dequantize op, look closer. + auto dq_op = dyn_cast_or_null( + q_op.getOperand().getDefiningOp()); + if (!dq_op) { + return; + } + auto dq_arg = dq_op.getOperand(); + + if (!dq_arg.hasOneUse()) { + // The initial quantization is used someplace else ... so it might be + // reasonable for it to be requantized for another purpose. + // TODO: ideally we would still want to check whether the requantization + // narrows rather than widens the representation. + return; + } + + // Invariant: + // isa(dq_arg.getDefiningOp()) --> + // dq_arg.getType() != q_op.getResult().getType() + // + // as otherwise the qdq pair would have been optimized away. + auto qd_arg_def_q_op = + dyn_cast_or_null(dq_arg.getDefiningOp()); + if (!qd_arg_def_q_op) { + return; + } + + qd_arg_def_q_op.emitWarning() + << " quantizer's output has another quantizer (" << q_op.getLoc() + << ") as consumer - intentional?"; + }); } using PrepareQuantStats = @@ -257,15 +308,16 @@ void PrepareQuantizePass::runOnFunction() { // convert all of them to signed.
OwningRewritePatternList patterns; bool is_signed = quant_specs_.IsSignedInferenceType(); + int bit_width = quant_specs_.GetQuantizationTypeWidth(); if (is_signed) { patterns.insert>(ctx); // Convert quant stats to int8 quantization parameters. // Currently, only activation stats are imported, so narrow_range = false. - patterns.insert(8, false, true, ctx); + patterns.insert(bit_width, false, true, ctx); } else { // Convert quant stats to uint8 quantization parameters. // Currently, only activation stats are imported, so narrow_range = false. - patterns.insert(8, false, false, ctx); + patterns.insert(bit_width, false, false, ctx); } applyPatternsAndFoldGreedily(func, patterns); diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index c5211bdfadb..3310c521a5a 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -82,13 +82,48 @@ class PrepareTFPass : public PassWrapper { bool unfold_batch_matmul_; }; +template +struct FetchConstantMinMaxInputs { + using AttrType = DenseFPElementsAttr; + bool operator()(TFFakeQuantOp tf_op, AttrType &min_value, + AttrType &max_value) const { + Value min = tf_op.min(), max = tf_op.max(); + + // TODO: incomplete; neither IdentityN ops + // nor chains of Identity* (not rare) are handled. + if (auto id1 = dyn_cast_or_null(min.getDefiningOp())) + min = id1.input(); + if (auto id2 = dyn_cast_or_null(max.getDefiningOp())) + max = id2.input(); + if (!matchPattern(min, m_Constant(&min_value))) { + return false; + } + if (!matchPattern(max, m_Constant(&max_value))) { + return false; + } + return true; // Successfully matched and fetched. + } +}; + +template +struct FetchMinMaxAttrs { + using AttrType = FloatAttr; + bool operator()(TFFakeQuantOp tf_op, AttrType &min_value, + AttrType &max_value) const { + min_value = tf_op.minAttr(); + max_value = tf_op.maxAttr(); + return true; // Successfully matched and fetched. + } +}; + // TODO(fengliuai): move this rule to PreparePatterns.td // TODO(fengliuai): reuse the quantization/tensorflow/tf_to_quant pass. // TODO(b/140968741): propagate the sign from the command line. Currently all // the FakeQuant is assumed to targeting UIN8, but per-channel kernel is // actually INT8. // Inserts a "tfl.quantize" and "tfl.dequantize" op pair (QDQs) after the -// "tf.FakeQuantWithMinMaxVarsOp" to be constant folded. Since the constant +// tf.FakeQuantWithMinMax{Vars|VarsPerChannel|Args}Op +// to be constant folded. Since the constant // folding logic will use a "std.constant" op to replace the // "tf.FakeQuantWithMinMaxVarsOp", the "tfl.quantize" op is used to preserve // the quantization parameters as a TypeAttr and "tfl.dequantize" op used to @@ -112,33 +147,49 @@ class PrepareTFPass : public PassWrapper { // | // tf.dequantize // | -template +// +// +// Warns if the (most likely unwanted, currently not quite correctly handled) +// case of back-to-back tf.FakeQuant occurs +// +// tf.FakeQuant* +// | +// tf.FakeQuant* +// +// tf.identity / tf.IdentityN between the tf.FakeQuant* ops +// need no special treatment, as they are already eliminated before the rewrites / +// check is applied.
+// + +template struct InsertTFLQuantOpsAfterTFFakeQuantOp : public OpRewritePattern { - using BaseType = InsertTFLQuantOpsAfterTFFakeQuantOp; + using BaseType = + InsertTFLQuantOpsAfterTFFakeQuantOp; - explicit InsertTFLQuantOpsAfterTFFakeQuantOp( - MLIRContext *ctx) + explicit InsertTFLQuantOpsAfterTFFakeQuantOp(MLIRContext *ctx) : OpRewritePattern(ctx) {} + FetchMinMax fetchMinMax; + + using FetchAttrType = typename FetchMinMax::AttrType; LogicalResult matchAndRewrite(TFFakeQuantOp tf_op, PatternRewriter &rewriter) const override { // We don't want to insert quantize/dequantize if the quantize op exists. auto res = tf_op.outputs(); - if (!res.hasOneUse() || isa(*res.user_begin())) + if (!res.hasOneUse() || isa(*res.user_begin())) { return failure(); + } // Extract the min/max constant values from the operands. We also consider // a special case that there are tf.Identity ops between the min/max // constants and the tf.FakeQuantWithMinMaxVarsOp. - Value min = tf_op.min(), max = tf_op.max(); - DenseFPElementsAttr min_value, max_value; - if (auto id1 = dyn_cast_or_null(min.getDefiningOp())) - min = id1.input(); - if (auto id2 = dyn_cast_or_null(max.getDefiningOp())) - max = id2.input(); - if (!matchPattern(min, m_Constant(&min_value))) return failure(); - if (!matchPattern(max, m_Constant(&max_value))) return failure(); + + FetchAttrType min_value, max_value; + if (!fetchMinMax(tf_op, min_value, max_value)) { + return failure(); + } int quant_dim = -1; if (PerAxis) { @@ -155,7 +206,9 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp TypeAttr qtype = quant::GetQuantizedTypeAttr( rewriter, res_type, min_value, max_value, quant_dim, num_bits, narrow_range, /*is_signed=*/false); - if (!qtype) failure(); + if (!qtype) { + return failure(); + } // Finally, use the quantization parameter to create the quantize and // dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp @@ -172,12 +225,22 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp } }; -using PreparePerTensorFakeQuant = - InsertTFLQuantOpsAfterTFFakeQuantOp; +// +// Three instances of the rule to cover the three different types of +// TF::FakeQuant operators +// +using PreparePerTensorFakeQuant = InsertTFLQuantOpsAfterTFFakeQuantOp< + TF::FakeQuantWithMinMaxVarsOp, /*PerAxis=*/false, + FetchConstantMinMaxInputs>; -using PreparePerChannelFakeQuant = - InsertTFLQuantOpsAfterTFFakeQuantOp; +using PreparePerChannelFakeQuant = InsertTFLQuantOpsAfterTFFakeQuantOp< + TF::FakeQuantWithMinMaxVarsPerChannelOp, /*PerAxis=*/true, + FetchConstantMinMaxInputs>; + +using PreparePerTensorFakeQuantWithMinMaxArgs = + InsertTFLQuantOpsAfterTFFakeQuantOp< + TF::FakeQuantWithMinMaxArgsOp, /*PerAxis=*/false, + FetchMinMaxAttrs>; // Templated class for declaring a converter from some TensorFlow convolution // op into its counterpart in TensorFlow Lite. @@ -644,9 +707,10 @@ void PrepareTFPass::runOnFunction() { // This pattern was intented to uses TFL QDQs to preserve the quantization // parameters from the TF Quant ops, thus this pattern should run with the - // first `applyPatternsAndFoldGreedily` method, which would otherwise removes - // the TF FakeQuant ops by the constant folding. - patterns.insert(ctx); + // first `applyPatternsGreedily` method, which would otherwise removes the + // TF FakeQuant ops by the constant folding. + patterns.insert(ctx); // This pattern will try to identify and optimize for dilated convolution. // e.g. 
Patterns like "SpaceToBatchND -> Conv2D -> BatchToSpaceND" will be diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc index 5df57de6f71..081ba7ac6e7 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc @@ -35,6 +35,7 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/core/platform/test.h" namespace mlir { @@ -92,7 +93,9 @@ class LstmUtilsTest : public ::testing::Test { LstmUtilsTest() {} void SetUp() override { - builder_ = std::unique_ptr(new Builder(&context_)); + RegisterDialects(); + context_ = std::make_unique(); + builder_ = std::unique_ptr(new Builder(context_.get())); fused_lstm_func_ = createLstmCompositeFunc(builder_.get(), false, false); fused_lstm_func_cifg_ = createLstmCompositeFunc(builder_.get(), false, true); @@ -105,10 +108,17 @@ class LstmUtilsTest : public ::testing::Test { fused_ln_lstm_func_.erase(); builder_.reset(); } + + void RegisterDialects() { + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); + } + FuncOp fused_lstm_func_; FuncOp fused_lstm_func_cifg_; FuncOp fused_ln_lstm_func_; - mlir::MLIRContext context_; + std::unique_ptr context_; std::unique_ptr builder_; }; diff --git a/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc b/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc new file mode 100644 index 00000000000..12929152d1e --- /dev/null +++ b/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc @@ -0,0 +1,127 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Identifier.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace TFL { + +namespace { + +constexpr char kWhitespaceTokenizer[] = "tftext:WhitespaceTokenizer"; +constexpr char kTFAPIImplements[] = "tf.api_implements"; + +inline OpaqueElementsAttr emptyCustomOption(OpBuilder* builder) { + std::string content = ""; + ShapedType type = RankedTensorType::get( + {static_cast(content.size())}, builder->getIntegerType(8)); + return OpaqueElementsAttr::get( + builder->getContext()->getRegisteredDialect("tfl"), type, content); +} + +inline RankedTensorType getInputType(mlir::FuncOp func, int idx) { + return func.getType() + .getInput(idx) + .dyn_cast_or_null(); +} + +inline RankedTensorType getResultType(mlir::FuncOp func, int idx) { + return func.getType() + .getResult(idx) + .dyn_cast_or_null(); +} + +LogicalResult VerifyWhitespaceTokenizer(mlir::FuncOp func) { + if (func.getNumResults() != 2) { + return failure(); + } + if (func.getNumArguments() != 1) { + return failure(); + } + auto input_type = getInputType(func, 0); + if (!input_type || input_type.getRank() != 1 || + !input_type.getElementType().isa()) { + return failure(); + } + auto value_type = getResultType(func, 0); + if (!value_type || value_type.getRank() != 1 || + !value_type.getElementType().isa()) { + return failure(); + } + auto offset_type = getResultType(func, 1); + if (offset_type.getRank() != 1 || + !offset_type.getElementType().isInteger(64)) { + return failure(); + } + return success(); +} + +LogicalResult ConvertWhitespaceTokenizer(mlir::FuncOp func, + llvm::StringRef api) { + func.eraseBody(); + func.addEntryBlock(); + func.setAttr(kTFAPIImplements, StringAttr::get(api, func.getContext())); + + Value text = func.getArgument(0); + auto output_type = func.getType().getResult(0); + auto offset_type = func.getType().getResult(1); + SmallVector shape = {output_type, offset_type}; + ArrayRef output_types(shape); + + OpBuilder builder(func.getBody()); + + auto op = builder.create(func.getLoc(), output_types, + ValueRange(text), api, + emptyCustomOption(&builder)); + + builder.create(func.getLoc(), op.getResults()); + return success(); +} +} // namespace + +LogicalResult ConvertTFTextAPI(mlir::FuncOp func, llvm::StringRef api) { + if (api.str() == kWhitespaceTokenizer) { + if (succeeded(VerifyWhitespaceTokenizer(func))) { + return 
ConvertWhitespaceTokenizer(func, api); + } + } + return failure(); +} + +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/utils/tftext_utils.h b/tensorflow/compiler/mlir/lite/utils/tftext_utils.h new file mode 100644 index 00000000000..283e57c179a --- /dev/null +++ b/tensorflow/compiler/mlir/lite/utils/tftext_utils.h @@ -0,0 +1,39 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header file defines common utils used by TFLite transformation +// passes to work with op attributes. + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_TFTEXT_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_TFTEXT_UTILS_H_ + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" + +namespace mlir { +namespace TFL { + +LogicalResult ConvertTFTextAPI(mlir::FuncOp func, llvm::StringRef api); + +} // end namespace TFL +} // end namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_TFTEXT_UTILS_H_ diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc index 11d3e7332db..b2225ec1c4f 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc @@ -21,6 +21,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_os_ostream.h" +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -93,9 +94,10 @@ MlirOptimizationPassRegistry& MlirOptimizationPassRegistry::Global() { static void RegisterDialects() { static bool init_once = []() { mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); mlir::registerDialect(); mlir::registerDialect(); - mlir::registerDialect(); return true; }(); (void)init_once; diff --git a/tensorflow/compiler/mlir/python/BUILD b/tensorflow/compiler/mlir/python/BUILD index 666f89ac72f..1189a926383 100644 --- a/tensorflow/compiler/mlir/python/BUILD +++ b/tensorflow/compiler/mlir/python/BUILD @@ -12,6 +12,22 @@ cc_library( "//tensorflow/c:tf_status_helper", "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", "//tensorflow/compiler/mlir/tensorflow:error_util", + # (yongtang) The graph_optimization_pass_registration needs to be part + # of a shared object that will be loaded whenever `import tensorflow` + # is run. 
The natural place is libtensorflow_framework.so. + # While adding graph_optimization_pass_registration to + # libtensorflow_framework.so is possible with some modification in + # dependency, many tests will fail due to multiple copies of LLVM. + # See https://github.com/tensorflow/tensorflow/pull/39231 for details. + # Alternatively, we place graph_optimization_pass_registration here + # because: + # - tensorflow/python/_pywrap_mlir.so already depends on LLVM anyway + # - tensorflow/python/_pywrap_mlir.so always loaded as part of python + # binding + # TODO: It might be still preferrable to place graph_optimization_pass + # as part of the libtensorflow_framework.so, as it is the central + # place for core related components. + "//tensorflow/compiler/mlir/tensorflow:graph_optimization_pass_registration", "//tensorflow/compiler/mlir/tensorflow:import_utils", "@llvm-project//llvm:support", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/python/mlir.cc b/tensorflow/compiler/mlir/python/mlir.cc index d0f6e015922..f22fb519a64 100644 --- a/tensorflow/compiler/mlir/python/mlir.cc +++ b/tensorflow/compiler/mlir/python/mlir.cc @@ -112,7 +112,7 @@ std::string ExperimentalConvertSavedModelV1ToMlir( // Convert the SavedModelBundle to an MLIR module. mlir::MLIRContext context; - auto module_or = ConvertSavedModelV1ToMlir(bundle, &context); + auto module_or = ConvertSavedModelV1ToMlir(bundle, {}, &context); if (!module_or.status().ok()) { Set_TF_Status_from_Status(status, module_or.status()); return "// error"; diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD b/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD new file mode 100644 index 00000000000..78f4312da46 --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD @@ -0,0 +1,41 @@ +load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension") + +package(licenses = ["notice"]) + +tf_python_pybind_extension( + name = "mlir_wrapper", + srcs = [ + "attrs.cc", + "basic_classes.cc", + "builders.cc", + "mlir_wrapper.cc", + "mlir_wrapper.h", + "ops.cc", + "types.cc", + ], + module_name = "mlir_wrapper", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/python:pybind11_lib", + "//tensorflow/python:pybind11_status", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", + "@pybind11", + ], +) + +tf_python_pybind_extension( + name = "filecheck_wrapper", + srcs = ["filecheck_wrapper.cc"], + module_name = "filecheck_wrapper", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/python:pybind11_lib", + "//tensorflow/python:pybind11_status", + "@llvm-project//llvm:support", + "@pybind11", + ], +) diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/attrs.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/attrs.cc new file mode 100644 index 00000000000..ca7faf2e1d3 --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/attrs.cc @@ -0,0 +1,25 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h" + +void init_attrs(py::module& m) { + py::class_(m, "Attribute"); + py::class_(m, "IntegerAttr") + .def("get", + py::overload_cast(&mlir::IntegerAttr::get)); +} diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/basic_classes.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/basic_classes.cc new file mode 100644 index 00000000000..25adb44fe1d --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/basic_classes.cc @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llvm/Support/FileCheck.h" +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h" + +void init_basic_classes(py::module& m) { + py::class_(m, "MLIRContext").def(py::init<>()); + + py::class_(m, "Location"); + + py::class_(m, "UnknownLoc") + .def("get", &mlir::UnknownLoc::get); + + py::class_(m, "Region") + .def("back", &mlir::Region::back, py::return_value_policy::reference) + .def("front", &mlir::Region::front, py::return_value_policy::reference) + .def("add_block", [](mlir::Region& r) { r.push_back(new mlir::Block); }) + .def("push_back", &mlir::Region::push_back) + .def("size", [](mlir::Region& r) { return r.getBlocks().size(); }) + .def("front", &mlir::Region::front, py::return_value_policy::reference); + py::class_(m, "Block_Iterator"); + py::class_(m, "Block") + .def("new", ([]() { return new mlir::Block; }), + py::return_value_policy::reference) + .def("end", &mlir::Block::end) + .def("addArgument", &mlir::Block::addArgument); + + py::class_(m, "Value").def("getType", &mlir::Value::getType); + py::class_(m, "OpResult"); + py::class_(m, "BlockArgument"); +} diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/builders.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/builders.cc new file mode 100644 index 00000000000..338f17ed6df --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/builders.cc @@ -0,0 +1,51 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/IR/Builders.h" // from @llvm-project + +#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h" + +void init_builders(py::module& m) { + py::class_(m, "Builder") + .def(py::init()) + .def("getFunctionType", + [](mlir::Builder& b, std::vector inputs, + std::vector outputs) { + return b.getFunctionType(llvm::ArrayRef(inputs), + llvm::ArrayRef(outputs)); + }); + py::class_(m, "OpBuilder") + .def(py::init()) + .def(py::init()) + .def(py::init()) + .def(py::init()) + .def("getUnknownLoc", &mlir::OpBuilder::getUnknownLoc) + .def("setInsertionPoint", + py::overload_cast( + &mlir::OpBuilder::setInsertionPoint)) + .def("saveInsertionPoint", &mlir::OpBuilder::saveInsertionPoint) + .def("restoreInsertionPoint", &mlir::OpBuilder::restoreInsertionPoint) + .def( + "createOperation", + [](mlir::OpBuilder& opb, mlir::OperationState& state) { + return opb.createOperation(state); + }, + py::return_value_policy::reference) + .def("getContext", &mlir::OpBuilder::getContext, + py::return_value_policy::reference); + + py::class_(m, "OpBuilder_InsertionPoint") + .def("getBlock", &mlir::OpBuilder::InsertPoint::getBlock); +} diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc new file mode 100644 index 00000000000..8a841856b72 --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc @@ -0,0 +1,36 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "llvm/Support/FileCheck.h" +#include "llvm/Support/SourceMgr.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "tensorflow/python/lib/core/pybind11_lib.h" +#include "tensorflow/python/lib/core/pybind11_status.h" + +PYBIND11_MODULE(filecheck_wrapper, m) { + m.def("check", [](std::string input, std::string check) { + llvm::FileCheckRequest fcr; + llvm::FileCheck fc(fcr); + llvm::SourceMgr SM = llvm::SourceMgr(); + SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input), + llvm::SMLoc()); + SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(check), + llvm::SMLoc()); + llvm::Regex regex = fc.buildCheckPrefixRegex(); + fc.readCheckFile(SM, llvm::StringRef(check), regex); + return fc.checkInput(SM, llvm::StringRef(input)); + }); +} diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc new file mode 100644 index 00000000000..6f468cd4267 --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h" + +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/python/lib/core/pybind11_lib.h" +#include "tensorflow/python/lib/core/pybind11_status.h" + +PYBIND11_MODULE(mlir_wrapper, m) { + m.def("registerDialects", []() { + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); + }); + + init_basic_classes(m); + init_types(m); + init_builders(m); + init_ops(m); + init_attrs(m); +} diff --git a/tensorflow/compiler/mlir/tfrt/runtime_fallback/dialect_static_registration.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h similarity index 57% rename from tensorflow/compiler/mlir/tfrt/runtime_fallback/dialect_static_registration.cc rename to tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h index 7632b0546fa..562c59b43e1 100644 --- a/tensorflow/compiler/mlir/tfrt/runtime_fallback/dialect_static_registration.cc +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h @@ -13,19 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -//===- dialect_static_registration.cc -------------------------------------===// -// -// This file registers the RuntimeFallbackDialect. 
-// -//===----------------------------------------------------------------------===// +#ifndef TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H +#define TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H -#include "tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_ops.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" -namespace mlir { -namespace tfd { +namespace py = pybind11; -// Static initialization for dialect registration. -static DialectRegistration tfd_registration; +void init_basic_classes(py::module& m); +void init_types(py::module& m); +void init_builders(py::module& m); +void init_ops(py::module& m); +void init_attrs(py::module& m); -} // namespace tfd -} // namespace mlir +#endif // TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/ops.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/ops.cc new file mode 100644 index 00000000000..4432829653e --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/ops.cc @@ -0,0 +1,194 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project + +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +void init_ops(py::module& m) { + py::class_>( + m, "Operation") + .def("getRegion", &mlir::Operation::getRegion, + py::return_value_policy::reference) + .def("getResult", &mlir::Operation::getResult) + .def("dump", &mlir::Operation::dump) + .def("getNumResults", &mlir::Operation::getNumResults); + + py::class_(m, "OperationState") + .def(py::init([](mlir::Location loc, std::string name) { + return mlir::OperationState(loc, llvm::StringRef(name)); + })) + .def("addTypes", + [](mlir::OperationState& state, std::vector tys) { + state.addTypes(mlir::ArrayRef(tys)); + }) + .def("addOperands", + [](mlir::OperationState& os, std::vector ops) { + os.addOperands(mlir::ArrayRef(ops)); + }) + .def("addRegion", py::overload_cast<>(&mlir::OperationState::addRegion), + py::return_value_policy::reference); + + py::class_(m, "ModuleOp") + .def("create", + [](mlir::Location loc) { return mlir::ModuleOp::create(loc); }) + .def("push_back", + [](mlir::ModuleOp& m, mlir::FuncOp f) { m.push_back(f); }) + .def("dump", &mlir::ModuleOp::dump) + .def("getAsStr", [](mlir::ModuleOp& m) { + std::string str; + llvm::raw_string_ostream os(str); + m.print(os); + return os.str(); + }); + + py::class_(m, "FuncOp") + .def("create", + [](mlir::Location location, std::string name, + mlir::FunctionType type) { + auto func = mlir::FuncOp::create(location, name, type); + func.addEntryBlock(); + return func; + }) + .def( + "getBody", + [](mlir::FuncOp& f) -> mlir::Region& { return f.getBody(); }, + 
py::return_value_policy::reference) + .def("getArguments", + [](mlir::FuncOp& f) { return f.getArguments().vec(); }) + .def("getName", [](mlir::FuncOp& f) { return f.getName().str(); }) + .def("getType", &mlir::FuncOp::getType); + + py::class_(m, "ReturnOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, + std::vector values) -> mlir::Operation* { + return opb + .create(loc, + mlir::ArrayRef(values)) + .getOperation(); + }); + + // mlir::TF::AddOp + py::class_(m, "Tf_AddV2Op") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) -> mlir::Operation* { + return opb.create(loc, x, y).getOperation(); + }); + + py::class_(m, "Tf_AnyOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value input, + mlir::Value reduction_indices, + bool keep_dims = false) -> mlir::Operation* { + return opb + .create(loc, opb.getI1Type(), input, + reduction_indices, keep_dims) + .getOperation(); + }); + + // mlir::TF::ConstOp + py::class_(m, "Tf_ConstOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, + mlir::Attribute value) -> mlir::Operation* { + return opb.create(loc, value).getOperation(); + }); + + // mlir::TF::EqualOp + py::class_(m, "Tf_EqualOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) -> mlir::Operation* { + return opb + .create(loc, x, y, opb.getBoolAttr(true)) + .getOperation(); + }); + + // mlir::TF::GreaterEqualOp + py::class_(m, "Tf_GreaterEqualOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) -> mlir::Operation* { + return opb.create(loc, x, y) + .getOperation(); + }); + + // mlir::TF::GreaterOp + py::class_(m, "Tf_GreaterOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) -> mlir::Operation* { + return opb.create(loc, x, y).getOperation(); + }); + + // mlir::TF::LegacyCallOp + py::class_(m, "Tf_LegacyCallOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, + std::vector output, std::vector args, + std::string f) -> mlir::Operation* { + return opb + .create( + loc, mlir::ArrayRef(output), + mlir::ArrayRef(args), mlir::StringRef(f)) + .getOperation(); + }); + + // mlir::TF::LessEqualOp + py::class_(m, "Tf_LessEqualOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) -> mlir::Operation* { + return opb.create(loc, x, y).getOperation(); + }); + + // mlir::TF::LessOp + py::class_(m, "Tf_LessOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) -> mlir::Operation* { + return opb.create(loc, x, y).getOperation(); + }); + + // mlir::TF::NegOp + py::class_(m, "Tf_NegOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, + mlir::Value x) -> mlir::Operation* { + return opb.create(loc, x).getOperation(); + }); + + py::class_(m, "Tf_NotEqualOp") + .def("create", [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) { + return opb + .create( + loc, x, y, mlir::BoolAttr::get(true, opb.getContext())) + .getOperation(); + }); + + // mlir::TF::SubOp + py::class_(m, "Tf_SubOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) -> mlir::Operation* { + return opb.create(loc, x, y).getOperation(); + }); +} diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc new file mode 100644 index 00000000000..2be67f8e93e 
--- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc @@ -0,0 +1,48 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +void init_types(py::module& m) { + // Type + py::class_ Type(m, "Type"); + Type.def("getKind", &mlir::Type::getKind); + + // Type Enums + py::enum_(Type, "StandardTypes_Kind") + .value("BF16", mlir::StandardTypes::BF16); + + // Type Sub-classes + py::class_(m, "FunctionType") + .def("getResults", + [](mlir::FunctionType& ft) { return ft.getResults().vec(); }); + + py::class_(m, "FloatType") + .def("get", &mlir::FloatType::get); + + py::class_(m, "IntegerType") + .def("get", py::overload_cast( + &mlir::IntegerType::get)); + + py::class_(m, "UnrankedTensorType") + .def("get", &mlir::UnrankedTensorType::get); + + py::class_(m, "RankedTensorType") + .def("get", [](std::vector shape, mlir::Type ty) { + return mlir::RankedTensorType::get(mlir::ArrayRef(shape), ty); + }); +} diff --git a/tensorflow/compiler/mlir/runlit.cfg.py b/tensorflow/compiler/mlir/runlit.cfg.py index 6d3131a781c..f1271d0da24 100644 --- a/tensorflow/compiler/mlir/runlit.cfg.py +++ b/tensorflow/compiler/mlir/runlit.cfg.py @@ -70,9 +70,9 @@ tool_dirs = config.mlir_tf_tools_dirs + [ ] tool_names = [ 'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate', - 'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate', - 'mlir-tflite-runner', 'tfcompile', 'json_to_flatbuffer', 'xla-gpu-opt', - 'xla-opt' + 'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate', + 'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile', + 'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt' ] tools = [ToolSubst(s, unresolved='ignore') for s in tool_names] llvm_config.add_tool_substitutions(tools, tool_dirs) diff --git a/tensorflow/compiler/mlir/runlit.site.cfg.py b/tensorflow/compiler/mlir/runlit.site.cfg.py index 661e6200df3..3e7596c75d7 100644 --- a/tensorflow/compiler/mlir/runlit.site.cfg.py +++ b/tensorflow/compiler/mlir/runlit.site.cfg.py @@ -44,6 +44,7 @@ mlir_tf_tools_dirs = [ 'tensorflow/compiler/mlir', 'tensorflow/compiler/mlir/lite', 'tensorflow/compiler/mlir/tensorflow', + 'tensorflow/compiler/mlir/tfjs', 'tensorflow/compiler/mlir/xla', 'tensorflow/compiler/aot', 'tensorflow/compiler/xla/service/mlir_gpu', diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 9099f2be2e1..05b2f891676 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -36,7 +36,7 @@ filegroup( "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td", "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", - 
"@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td", + "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", ], ) @@ -224,7 +224,10 @@ cc_library( hdrs = [ "ir/tf_attributes.h", ], - deps = ["@llvm-project//mlir:IR"], + deps = [ + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + ], ) cc_library( @@ -427,6 +430,7 @@ cc_library( "transforms/parallel_execute_to_islands.cc", "transforms/promote_resources_to_args.cc", "transforms/raise_control_flow.cc", + "transforms/readonly_references_to_resources.cc", "transforms/replicate_invariant_op_hoisting.cc", "transforms/replicate_to_island.cc", "transforms/resource_device_inference.cc", @@ -448,6 +452,7 @@ cc_library( "transforms/tpu_merge_variables_with_execute.cc", "transforms/tpu_rewrite_pass.cc", "transforms/tpu_sharding_identification_pass.cc", + "transforms/tpu_space_to_depth_pass.cc", "transforms/tpu_variable_runtime_reformatting.cc", "translate/breakup-islands.cc", "translate/control_to_executor_dialect.cc", @@ -555,8 +560,7 @@ cc_library( srcs = ["ir/dialect_registration.cc"], deps = [ ":tensorflow", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LoopOpsTransforms", + "@llvm-project//mlir:Shape", ], alwayslink = 1, ) @@ -785,6 +789,9 @@ cc_library( name = "convert_type", srcs = ["utils/convert_type.cc"], hdrs = ["utils/convert_type.h"], + visibility = [ + "//visibility:public", + ], deps = [ ":tensorflow_types", "//tensorflow/core:framework", @@ -823,6 +830,7 @@ cc_library( ":mangling_util", ":tensorflow_attributes", ":tensorflow_types", + "//tensorflow/compiler/xla:util", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -1074,7 +1082,7 @@ genrule( srcs = [ "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td", "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", - "@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td", + "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", "@llvm-project//mlir:include/mlir/IR/OpBase.td", "ir/tf_generated_ops.td", "ir/tf_op_base.td", @@ -1139,6 +1147,7 @@ COMPILE_MLIR_UTIL_DEPS = [ "//tensorflow/compiler/mlir/xla:type_to_shape", "//tensorflow/compiler/mlir/xla:xla_legalize_tf", "//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla", + "//tensorflow/compiler/mlir/xla:xla_sink_constants_to_control_flow", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core:framework", @@ -1277,6 +1286,7 @@ cc_library( "//tensorflow/stream_executor/lib", "@com_google_absl//absl/strings", "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", ], ) @@ -1291,6 +1301,7 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core/protobuf/tpu:topology_proto_cc", "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", ], ) diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc index ff1620347f7..f7b88317cd4 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc @@ -49,7 +49,7 @@ namespace TF { namespace { constexpr int64_t kUnknownResourceId = -1; -constexpr char kResourceArgUniqueIdAttr[] = "tf.resource_arg_unique_id"; +constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id"; // Returns if a VarHandleOp is anonymous, which means it always creates a new // variable. 
diff --git a/tensorflow/compiler/mlir/tensorflow/c/BUILD b/tensorflow/compiler/mlir/tensorflow/c/BUILD new file mode 100644 index 00000000000..9528874f419 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/c/BUILD @@ -0,0 +1,55 @@ +load( + "//tensorflow:tensorflow.bzl", + "tf_copts", + "tf_cuda_library", + "tfe_xla_copts", +) + +package( + default_visibility = [":friends"], + licenses = ["notice"], # Apache 2.0 +) + +package_group( + name = "friends", + packages = ["//tensorflow/..."], +) + +tf_cuda_library( + name = "mlir_c_api", + srcs = [ + "c_api_unified_experimental_mlir.cc", + ], + copts = tf_copts() + tfe_xla_copts(), + deps = [ + "//tensorflow/c:c_api", + "//tensorflow/c:tf_status_helper", + "//tensorflow/c:tf_status_internal", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_internal", + "//tensorflow/c/eager:c_api_unified_internal", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", + "//tensorflow/compiler/mlir/tensorflow:convert_type", + "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:errors", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "mlir_c_api_registration", + srcs = ["c_api_unified_experimental_mlir_registration.cc"], + deps = [ + ":mlir_c_api", + "//tensorflow/c/eager:c_api_unified_internal", + ], + alwayslink = 1, +) diff --git a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc new file mode 100644 index 00000000000..0e8b7fedd9b --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc @@ -0,0 +1,493 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/c/tf_status_internal.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" + +namespace mlir { +namespace TF { +using tensorflow::internal::AbstractFunction; +using tensorflow::internal::AbstractOp; +using tensorflow::internal::AbstractTensor; +using tensorflow::internal::dyncast; +using tensorflow::internal::ExecutionContext; +using tensorflow::internal::OutputList; + +namespace { + +static void RegisterDialects() { + static bool init_once = []() { + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); + return true; + }(); + (void)init_once; +} + +Status ConvertDataTypeToTensor(tensorflow::DataType dtype, Builder builder, + Type* type) { + Status s = tensorflow::ConvertDataType(dtype, builder, type); + if (s.ok()) *type = UnrankedTensorType::get(*type); + return s; +} + +class MlirTensor : public AbstractTensor { + public: + explicit MlirTensor(Value value) : AbstractTensor(kKind), value_(value) {} + + Value getValue() { return value_; } + + static constexpr AbstractTensorKind kKind = kMlirTensor; + + private: + Value value_; +}; + +class MlirAbstractOp : public AbstractOp { + public: + explicit MlirAbstractOp(MLIRContext* context) + : AbstractOp(kKind), context_(context) {} + + void SetOpType(const char* op_type, TF_Status* s) override; + + void SetAttrType(const char* attr_name, TF_DataType dtype, + TF_Status* s) override; + + void SetOpName(const char* const op_name, TF_Status* s) override; + + MLIRContext* GetContext() { return context_; } + + Type AddRef(Type type, TF_Status* s); + + OperationState* Create(ArrayRef operands, TF_Status* s); + + static constexpr AbstractOpKind kKind = kMlirOp; + + private: + MLIRContext* context_; + llvm::StringMap attrs_; + std::unique_ptr state_; + const char* op_name_ = 
nullptr; +}; + +// MlirFunction is a thin wrapper over a FuncOp. +class MlirFunction : public AbstractFunction { + public: + explicit MlirFunction(std::unique_ptr context, + OwningModuleRef module, FuncOp func) + : AbstractFunction(kKind), + context_(std::move(context)), + module_(std::move(module)), + func_(func) {} + + TF_Function* GetTfFunction(TF_Status* s) override; + + static constexpr AbstractFunctionKind kKind = kGraphFunc; + + private: + std::unique_ptr context_; + OwningModuleRef module_; + FuncOp func_; +}; + +class MlirFunctionContext : public ExecutionContext { + public: + explicit MlirFunctionContext(const char* name) + : ExecutionContext(kKind), + context_(std::make_unique()), + builder_(context_.get()) { + // TODO(aminim) figure out the location story here + module_ = ModuleOp::create(builder_.getUnknownLoc()); + func_ = FuncOp::create(builder_.getUnknownLoc(), name, + builder_.getFunctionType(llvm::None, llvm::None)); + module_->push_back(func_); + builder_ = OpBuilder::atBlockBegin(func_.addEntryBlock()); + } + + AbstractOp* CreateOperation() override { + return new MlirAbstractOp(context_.get()); + } + + void ExecuteOperation(AbstractOp* abstract_op, int num_inputs, + AbstractTensor* const* inputs, OutputList* o, + TF_Status* s) override; + + AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override; + + AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override; + + void RegisterFunction(AbstractFunction* func, TF_Status* s) override { + s->status = tensorflow::errors::Unimplemented( + "Registering graph functions has not been implemented yet."); + } + + static constexpr ExecutionContextKind kKind = kMlirContext; + + private: + std::unique_ptr context_; + OpBuilder builder_; + FuncOp func_; + OwningModuleRef module_; +}; + +void MlirAbstractOp::SetOpType(const char* op_type, TF_Status* s) { + if (state_) { + s->status = tensorflow::errors::FailedPrecondition( + "SetOpType called on already built op."); + return; + } + std::string name = "tf."; + name += op_type; + // TODO(aminim) figure out the location story here + state_ = std::make_unique(UnknownLoc::get(context_), name); +} + +void MlirAbstractOp::SetAttrType(const char* attr_name, TF_DataType dtype, + TF_Status* s) { + if (!state_) { + s->status = tensorflow::errors::FailedPrecondition( + "op_type must be specified before specifying attrs."); + return; + } + Type mlir_type; + Builder builder(context_); + s->status = ConvertDataTypeToTensor(static_cast(dtype), + builder, &mlir_type); + if (!s->status.ok()) return; + attrs_[attr_name] = TypeAttr::get(mlir_type); +} + +void MlirAbstractOp::SetOpName(const char* const op_name, TF_Status* s) { + // TODO(aminim): should we use a location? 
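Taken together, the classes above give the tracing flow: a MlirFunctionContext owns the MLIRContext, module and entry block; AddParameter adds typed block arguments; each MlirAbstractOp is configured via SetOpType/SetAttrType and then materialized by ExecuteOperation. A rough sketch of that call sequence, using only the methods declared above (error handling and ownership elided; these classes live in this file, not in an exported header):

```c++
// Sketch only: trace z = tf.AddV2(x, y) into a function named "add_fn".
void TraceAddExample() {
  TF_Status* status = TF_NewStatus();
  MlirFunctionContext ctx("add_fn");

  AbstractTensor* x = ctx.AddParameter(TF_FLOAT, status);
  AbstractTensor* y = ctx.AddParameter(TF_FLOAT, status);

  AbstractOp* add = ctx.CreateOperation();
  add->SetOpType("AddV2", status);  // recorded as the MLIR op name "tf.AddV2"
  add->SetOpName("z", status);

  OutputList results;
  AbstractTensor* inputs[] = {x, y};
  ctx.ExecuteOperation(add, /*num_inputs=*/2, inputs, &results, status);
  // results.outputs now holds a MlirTensor wrapping the tf.AddV2 result.
  TF_DeleteStatus(status);
}
```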
+ if (op_name_) { + s->status = tensorflow::errors::FailedPrecondition( + "SetOpName called on already built op."); + return; + } + op_name_ = op_name; +} + +Type MlirAbstractOp::AddRef(Type type, TF_Status* s) { + Type elt_type = getElementTypeOrSelf(type); + if (elt_type.isa()) { + s->status = tensorflow::errors::InvalidArgument( + "Requested reference to a reference type"); + return nullptr; + } + elt_type = TensorFlowRefType::get(elt_type); + if (RankedTensorType tensor_type = type.dyn_cast()) { + return RankedTensorType::get(tensor_type.getShape(), elt_type); + } + return UnrankedTensorType::get(elt_type); +} + +OperationState* MlirAbstractOp::Create(ArrayRef operands, TF_Status* s) { + state_->operands = llvm::to_vector<4>(operands); + const tensorflow::OpDef* op_def; + auto node_name = state_->name.getStringRef().drop_front( + TensorFlowDialect::getDialectNamespace().size() + 1); + s->status = + tensorflow::OpRegistry::Global()->LookUpOpDef(node_name.str(), &op_def); + if (!s->status.ok()) return nullptr; + Builder builder(context_); + // Process operands according to the op_def and infer derived attributes. + int current_operand = 0; + for (const tensorflow::OpDef::ArgDef& input_arg : op_def->input_arg()) { + if (!input_arg.number_attr().empty()) { + // TODO(b/156122856): we don't support variadic operands. + s->status = tensorflow::errors::Unimplemented( + "Unsupported 'number_attr' for '", input_arg.number_attr(), "'"); + return nullptr; + } else if (!input_arg.type_list_attr().empty()) { + s->status = tensorflow::errors::InvalidArgument( + "Unsupported 'type_list_attr' for '", input_arg.number_attr(), "'"); + return nullptr; + } + if (current_operand >= operands.size()) { + s->status = tensorflow::errors::InvalidArgument("Missing operand for '", + input_arg.name(), "'"); + return nullptr; + } + Type expected_type; + if (input_arg.type() != tensorflow::DT_INVALID) { + s->status = + ConvertDataTypeToTensor(input_arg.type(), builder, &expected_type); + if (!s->status.ok()) return nullptr; + if (input_arg.is_ref()) expected_type = AddRef(expected_type, s); + if (!s->status.ok()) return nullptr; + } else { + expected_type = operands[current_operand].getType(); + } + if (!input_arg.type_attr().empty()) { + attrs_[input_arg.type_attr()] = TypeAttr::get(expected_type); + } + ++current_operand; + } + + for (const tensorflow::OpDef::ArgDef& output_arg : op_def->output_arg()) { + int original_size = state_->types.size(); + if (!output_arg.number_attr().empty()) { + // Same type repeated "repeats" times. + Attribute repeats_attr = attrs_[output_arg.number_attr()]; + if (!repeats_attr) { + s->status = tensorflow::errors::InvalidArgument( + "Missing attribute '", output_arg.number_attr(), + "' required for output list '", output_arg.name(), "'"); + return nullptr; + } + if (!repeats_attr.isa()) { + s->status = tensorflow::errors::InvalidArgument( + "Attribute '", output_arg.number_attr(), + "' required for output list '", output_arg.name(), + "' isn't an integer"); + return nullptr; + } + int64_t repeats = repeats_attr.cast().getInt(); + + if (!output_arg.type_attr().empty()) { + // Same type repeated "repeats" times. 
+ Attribute attr = attrs_[output_arg.type_attr()]; + if (!attr) { + s->status = tensorflow::errors::InvalidArgument( + "Missing attribute '", output_arg.type_attr(), + "' required for output '", output_arg.name(), "'"); + return nullptr; + } + TypeAttr type_attr = attr.dyn_cast(); + if (!type_attr) { + s->status = tensorflow::errors::InvalidArgument( + "Attribute '", output_arg.type_attr(), "' required for output '", + output_arg.name(), "' isn't a type attribute"); + return nullptr; + } + for (int i = 0; i < repeats; ++i) + state_->types.push_back(type_attr.getType()); + } else if (output_arg.type() != tensorflow::DT_INVALID) { + for (int i = 0; i < repeats; ++i) { + Type type; + s->status = + ConvertDataTypeToTensor(output_arg.type(), builder, &type); + if (!s->status.ok()) return nullptr; + state_->types.push_back(type); + } + } else { + s->status = tensorflow::errors::InvalidArgument( + "Missing type or type_attr field in ", + output_arg.ShortDebugString()); + return nullptr; + } + } else if (!output_arg.type_attr().empty()) { + Attribute attr = attrs_[output_arg.type_attr()]; + if (!attr) { + s->status = tensorflow::errors::InvalidArgument( + "Missing attribute '", output_arg.type_attr(), + "' required for output '", output_arg.name(), "'"); + return nullptr; + } + TypeAttr type_attr = attr.dyn_cast(); + if (!type_attr) { + s->status = tensorflow::errors::InvalidArgument( + "Attribute '", output_arg.type_attr(), "' required for output '", + output_arg.name(), "' isn't a type attribute"); + return nullptr; + } + state_->types.push_back(type_attr.getValue()); + } else if (!output_arg.type_list_attr().empty()) { + // This is pointing to an attribute which is an array of types. + Attribute attr = attrs_[output_arg.type_list_attr()]; + if (!attr) { + s->status = tensorflow::errors::InvalidArgument( + "Missing attribute '", output_arg.type_list_attr(), + "' required for output '", output_arg.name(), "'"); + return nullptr; + } + ArrayAttr array_attr = attr.dyn_cast(); + if (!array_attr) { + s->status = tensorflow::errors::InvalidArgument( + "Attribute '", output_arg.type_list_attr(), + "' required for output '", output_arg.name(), + "' isn't an array attribute"); + return nullptr; + } + for (Attribute attr : array_attr) { + TypeAttr type_attr = attr.dyn_cast(); + if (!type_attr) { + s->status = tensorflow::errors::InvalidArgument( + "Array Attribute '", output_arg.type_list_attr(), + "' required for output '", output_arg.name(), + "' has a non-Type element"); + return nullptr; + } + state_->types.push_back(type_attr.getValue()); + } + } else if (output_arg.type() != tensorflow::DT_INVALID) { + Type type; + Builder builder(context_); + s->status = ConvertDataTypeToTensor(output_arg.type(), builder, &type); + if (!s->status.ok()) return nullptr; + state_->types.push_back(type); + } else { + s->status = tensorflow::errors::InvalidArgument( + "No type fields in ", output_arg.ShortDebugString()); + if (!s->status.ok()) return nullptr; + } + if (output_arg.is_ref()) { + // For all types that were added by this function call, make them refs. 
+ for (Type& type : llvm::make_range(&state_->types[original_size], + state_->types.end())) { + type = AddRef(type, s); + if (!s->status.ok()) return nullptr; + } + } + } + return state_.get(); +} + +TF_Function* MlirFunction::GetTfFunction(TF_Status* s) { + PassManager pm(func_.getContext()); + pm.addNestedPass(CreateFunctionalToExecutorDialectConversionPass()); + pm.addNestedPass(CreateBreakUpIslandsPass()); + + // In case of failure, the `diag_handler` converts MLIR errors emitted to + // the MLIRContext into a tensorflow::Status. + StatusScopedDiagnosticHandler diag_handler(func_.getContext()); + LogicalResult result = pm.run(func_.getParentOfType()); + (void)result; + s->status = diag_handler.ConsumeStatus(); + if (!s->status.ok()) return nullptr; + + tensorflow::GraphExportConfig configs; + std::unique_ptr tf_function(new TF_Function); + s->status = ConvertMlirFunctionToFunctionLibraryDef(func_, configs, + &tf_function->fdef); + return tf_function.release(); +} + +void MlirFunctionContext::ExecuteOperation(AbstractOp* abstract_op, + int num_inputs, + AbstractTensor* const* inputs, + OutputList* o, TF_Status* s) { + auto* mlir_op = dyncast(abstract_op); + if (mlir_op == nullptr) { + s->status = tensorflow::errors::InvalidArgument( + "Unable to cast AbstractOp to TF_GraphOp."); + return; + } + SmallVector operands; + for (int i = 0; i < num_inputs; ++i) { + auto* operand = dyncast(inputs[i]); + if (!operand) { + s->status = tensorflow::errors::InvalidArgument( + "Capturing eager tensors is not supported yet."); + return; + } + if (operand->getValue().getContext() != context_.get()) { + s->status = tensorflow::errors::InvalidArgument( + "Capturing tensors from other context is not supported."); + return; + } + operands.push_back(operand->getValue()); + } + OperationState* state = mlir_op->Create(operands, s); + if (!s->status.ok() || !state) return; + Operation* op = builder_.createOperation(*state); + int num_results = op->getNumResults(); + o->outputs.clear(); + o->outputs.reserve(num_results); + for (Value result : op->getResults()) + o->outputs.push_back(new MlirTensor(result)); +} + +AbstractTensor* MlirFunctionContext::AddParameter(TF_DataType dtype, + TF_Status* s) { + Type type; + s->status = ConvertDataTypeToTensor(static_cast(dtype), + builder_, &type); + if (!s->status.ok()) return nullptr; + return new MlirTensor(func_.getBody().front().addArgument(type)); +} + +AbstractFunction* MlirFunctionContext::Finalize(OutputList* outputs, + TF_Status* s) { + Block& body = func_.getBody().front(); + SmallVector ret_operands; + for (AbstractTensor* output : outputs->outputs) { + auto* operand = dyncast(output); + if (!operand) { + s->status = tensorflow::errors::InvalidArgument( + "Capturing eager tensors is not supported yet."); + return nullptr; + } + if (operand->getValue().getContext() != context_.get()) { + s->status = tensorflow::errors::InvalidArgument( + "Capturing tensors from other context is not supported."); + return nullptr; + } + ret_operands.push_back(operand->getValue()); + } + builder_.create(func_.getLoc(), ret_operands); + + auto arg_types = llvm::to_vector<8>(body.getArgumentTypes()); + auto result_types = + llvm::to_vector<8>(body.getTerminator()->getOperandTypes()); + func_.setType(FunctionType::get(arg_types, result_types, func_.getContext())); + return new MlirFunction(std::move(context_), std::move(module_), func_); +} + +extern "C" { +ExecutionContext* MlirTracingFactory(const char* fn_name, TF_Status* s) { + RegisterDialects(); + return new 
MlirFunctionContext(fn_name); +} +} + +} // end anonymous namespace +} // end namespace TF +} // end namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir_registration.cc b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir_registration.cc new file mode 100644 index 00000000000..778f4b777a3 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir_registration.cc @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" + +using tensorflow::internal::ExecutionContext; + +extern "C" { +ExecutionContext* MlirTracingFactory(const char* fn_name, TF_Status* s); +} + +namespace { +// Register the tracing implemented in this file as the default tracing engine. +static bool register_tracing = [] { + RegisterTracingEngineFactory("mlir", MlirTracingFactory); + return true; +}(); + +} // namespace diff --git a/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h index 15a4ecfc537..39245425a5a 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h @@ -26,7 +26,7 @@ limitations under the License. #include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project -#include "mlir/Interfaces/SideEffects.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project namespace mlir { namespace TFControlFlow { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc b/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc index ac468d9810c..c95d7b7ca7c 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -31,5 +32,6 @@ static DialectRegistration tf_device_dialect; static DialectRegistration tf_saved_model_dialect; +static DialectRegistration shape_dialect; } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.cc index 6797c04ebcf..dfad1fce26d 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" +#include "mlir/IR/Attributes.h" // from @llvm-project + namespace mlir { namespace TF { @@ -45,6 +47,28 @@ struct ShapeAttrStorage : public AttributeStorage { bool unranked = false; }; +// The storage class for FuncAttr. +struct FuncAttrStorage : public AttributeStorage { + using KeyTy = std::pair; + + explicit FuncAttrStorage(Attribute name, Attribute attrs) + : name(name), attrs(attrs) {} + + bool operator==(const KeyTy& key) const { return key == KeyTy(name, attrs); } + static unsigned hashKey(const KeyTy& key) { + return llvm::hash_combine(key.first, key.second); + } + + static FuncAttrStorage* construct(mlir::AttributeStorageAllocator& allocator, + const KeyTy& key) { + return new (allocator.allocate()) + FuncAttrStorage(key.first, key.second); + } + + Attribute name; + Attribute attrs; +}; + } // namespace detail // Get or create a shape attribute. @@ -85,5 +109,24 @@ bool ShapeAttr::hasStaticShape() const { return true; } +FuncAttr FuncAttr::get(mlir::MLIRContext* context, llvm::StringRef name, + DictionaryAttr attr) { + auto symbol = SymbolRefAttr::get(name, context); + return Base::get(context, AttrKind::FUNC, symbol, attr); +} + +FuncAttr FuncAttr::get(mlir::MLIRContext* context, SymbolRefAttr symbol, + DictionaryAttr attr) { + return Base::get(context, AttrKind::FUNC, symbol, attr); +} + +SymbolRefAttr FuncAttr::GetName() const { + return getImpl()->name.cast(); +} + +DictionaryAttr FuncAttr::GetAttrs() const { + return getImpl()->attrs.cast(); +} + } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h index 4d85dd95cea..1edc7356ab4 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h @@ -18,6 +18,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ATTRIBUTES_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ATTRIBUTES_H_ +#include "llvm/ADT/StringRef.h" #include "mlir/IR/Attributes.h" // from @llvm-project namespace mlir { @@ -25,11 +26,12 @@ namespace TF { namespace AttrKind { -// List of supported custom TensorFlow Attributes kinds, necessary for +// List of supported custom TensorFlow Attribute kinds, necessary for // isa/dyn_cast. 
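The FuncAttrStorage and accessors above define the new #tf.func attribute as a pair of a SymbolRefAttr and a DictionaryAttr. A hedged construction sketch (names are illustrative; assumes the TensorFlow dialect is registered with the context and the MLIR APIs as of this change, where DictionaryAttr::get takes the context last):

```c++
#include "mlir/IR/Attributes.h"   // from @llvm-project
#include "mlir/IR/Builders.h"     // from @llvm-project
#include "mlir/IR/Identifier.h"   // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"

void FuncAttrExample() {
  mlir::MLIRContext context;
  mlir::Builder builder(&context);

  // {some_attr = "value"} as a DictionaryAttr.
  mlir::NamedAttribute entry(mlir::Identifier::get("some_attr", &context),
                             builder.getStringAttr("value"));
  auto dict = mlir::DictionaryAttr::get({entry}, &context);

  // Prints as #tf.func<@my_fn, {some_attr = "value"}>.
  auto func_attr = mlir::TF::FuncAttr::get(&context, "my_fn", dict);
  mlir::SymbolRefAttr name = func_attr.GetName();
  mlir::DictionaryAttr attrs = func_attr.GetAttrs();
  (void)name;
  (void)attrs;
}
```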
enum Kind { FIRST_USED_TENSORFLOW_ATTR = Attribute::FIRST_TENSORFLOW_ATTR, SHAPE = FIRST_USED_TENSORFLOW_ATTR, + FUNC, LAST_USED_TENSORFLOW_ATTR, }; @@ -38,6 +40,7 @@ enum Kind { namespace detail { struct ShapeAttrStorage; +struct FuncAttrStorage; } // namespace detail @@ -71,6 +74,33 @@ class ShapeAttr : public Attribute::AttrBase. It is +// currently printed and parsed for the following format: +// +// #tf.func<@symbol, {attr = "value"}> +// +// where the first element is the SymbolRefAttr and the second element is the +// DictionaryAttr. +class FuncAttr + : public Attribute::AttrBase { + public: + using Base::Base; + + static FuncAttr get(mlir::MLIRContext* context, llvm::StringRef name, + DictionaryAttr attr); + + static FuncAttr get(mlir::MLIRContext* context, SymbolRefAttr symbol, + DictionaryAttr attr); + + SymbolRefAttr GetName() const; + + DictionaryAttr GetAttrs() const; + + static bool kindof(unsigned kind) { return kind == AttrKind::FUNC; } +}; + } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index d5ecbf3e292..9daebc22ba1 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -47,37 +47,6 @@ limitations under the License. namespace mlir { namespace tf_executor { -namespace { - -// If the given tensor has elements of type with subtypes, then returns a new -// type after dropping subtypes info. Otherwise, returns the original type as -// is. -ShapedType DropTypeSubTypes(ShapedType ty) { - Type element_ty = ty.getElementType(); - auto subtype_ty = element_ty.dyn_cast(); - if (!subtype_ty) return ty; - - Type default_ty = GetDefaultTypeOf(subtype_ty); - if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty); - - return UnrankedTensorType::get(default_ty); -} - -// If the given tensor has elements of type ref, then returns a new type -// of the shape, but corresponding non-ref type as element type. Otherwise, -// returns the original type as is. -ShapedType DropRefType(ShapedType ty) { - Type element_ty = ty.getElementType(); - auto ref_ty = element_ty.dyn_cast(); - if (!ref_ty) return ty; - - Type default_ty = GetDefaultTypeOf(ref_ty); - if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty); - - return UnrankedTensorType::get(default_ty); -} - -} // namespace //===----------------------------------------------------------------------===// // TF Executor Dialect @@ -85,6 +54,9 @@ ShapedType DropRefType(ShapedType ty) { namespace { +using TF::DropRefType; +using TF::DropTypeSubTypes; + struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 2b3dd529c3b..59443ce3547 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -192,6 +192,44 @@ retained with length 1. let verifier = [{ return Verify(*this); }]; } +def TF_AllToAllOp : TF_Op<"AllToAll", [NoSideEffect]> { + let summary = "An Op to exchange data across TPU replicas."; + + let description = [{ +On each replica, the input is split into `split_count` blocks along +`split_dimension` and send to the other replicas given group_assignment. 
After +receiving `split_count` - 1 blocks from other replicas, we concatenate the +blocks along `concat_dimension` as the output. + +For example, suppose there are 2 TPU replicas: +replica 0 receives input: `[[A, B]]` +replica 1 receives input: `[[C, D]]` + +group_assignment=`[[0, 1]]` +concat_dimension=0 +split_dimension=1 +split_count=2 + +replica 0's output: `[[A], [C]]` +replica 1's output: `[[B], [D]]` + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, + I32Tensor:$group_assignment, + + I64Attr:$concat_dimension, + I64Attr:$split_dimension, + I64Attr:$split_count + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_AngleOp : TF_Op<"Angle", [NoSideEffect, SameOperandsAndResultShape]> { let summary = "Returns the argument of a complex number."; @@ -1157,6 +1195,46 @@ subsequent operation and then be optimized away, however.) }]; } +def TF_CaseOp : TF_Op<"Case", []> { + let summary = [{ +An n-way switch statement which calls a single branch function. + }]; + + let description = [{ +An n-way switch statement, implementing the following: + ``` + switch (branch_index) { + case 0: + output = branches[0](input); + break; + case 1: + output = branches[1](input); + break; + ... + case [[nbranches-1]]: + default: + output = branches[nbranches-1](input); + break; + } + ``` + }]; + + let arguments = (ins + I32Tensor:$branch_index, + Variadic:$input, + + Confined]>:$branches, + DefaultValuedAttr:$output_shapes + ); + + let results = (outs + Variadic:$output + ); + + TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>; + TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>; +} + def TF_CastOp : TF_Op<"Cast", [NoSideEffect, SameOperandsAndResultShape]> { let summary = "Cast x of type SrcT to y of DstT."; @@ -1217,7 +1295,7 @@ that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect, SameOperandsAndResultType]> { +def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect]> { let summary = "Clips tensor values to a specified min and max."; let description = [{ @@ -1240,6 +1318,103 @@ greater than `clip_value_max` are set to `clip_value_max`. 
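Dropping SameOperandsAndResultType from ClipByValue above fits the op's contract, presumably because clip_value_min and clip_value_max are allowed to be scalars rather than tensors of the same type as t, so only t shares the result type. Its elementwise behaviour for scalar bounds, in plain C++ purely for illustration:

```c++
#include <algorithm>
#include <vector>

// Elementwise semantics of ClipByValue with scalar bounds (illustrative only;
// the real op also accepts bounds shaped like t).
std::vector<float> ClipByValue(std::vector<float> t, float clip_value_min,
                               float clip_value_max) {
  for (float& v : t)
    v = std::min(std::max(v, clip_value_min), clip_value_max);
  return t;
}

// ClipByValue({-2.0f, 0.5f, 3.0f}, -1.0f, 1.0f) == {-1.0f, 0.5f, 1.0f}
```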
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_CollectiveBcastRecvOp : TF_Op<"CollectiveBcastRecv", []> { + let summary = "Receives a tensor value broadcast from another device."; + + let description = [{ + }]; + + let arguments = (ins + I64Attr:$group_size, + I64Attr:$group_key, + I64Attr:$instance_key, + TF_ShapeAttr:$shape, + DefaultValuedAttr:$communication_hint + ); + + let results = (outs + TensorOf<[F16, F32, F64, I1, I32, I64]>:$data + ); + + TF_DerivedResultTypeAttr T = TF_DerivedResultTypeAttr<0>; +} + +def TF_CollectiveBcastSendOp : TF_Op<"CollectiveBcastSend", []> { + let summary = "Broadcasts a tensor value to one or more other devices."; + + let description = [{ + }]; + + let arguments = (ins + TensorOf<[F16, F32, F64, I1, I32, I64]>:$input, + + I64Attr:$group_size, + I64Attr:$group_key, + I64Attr:$instance_key, + TF_ShapeAttr:$shape, + DefaultValuedAttr:$communication_hint + ); + + let results = (outs + TensorOf<[F16, F32, F64, I1, I32, I64]>:$data + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_CollectiveGatherOp : TF_Op<"CollectiveGather", []> { + let summary = [{ +Mutually accumulates multiple tensors of identical type and shape. + }]; + + let description = [{ + }]; + + let arguments = (ins + TensorOf<[F16, F32, F64, I32, I64]>:$input, + + I64Attr:$group_size, + I64Attr:$group_key, + I64Attr:$instance_key, + TF_ShapeAttr:$shape, + DefaultValuedAttr:$communication_hint + ); + + let results = (outs + TensorOf<[F16, F32, F64, I32, I64]>:$data + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_CollectiveReduceOp : TF_Op<"CollectiveReduce", [SameOperandsAndResultType]> { + let summary = [{ +Mutually reduces multiple tensors of identical type and shape. + }]; + + let description = [{ + }]; + + let arguments = (ins + TensorOf<[F16, F32, F64, I32, I64]>:$input, + + I64Attr:$group_size, + I64Attr:$group_key, + I64Attr:$instance_key, + TF_AnyStrAttrOf<["Min", "Max", "Mul", "Add"]>:$merge_op, + TF_AnyStrAttrOf<["Id", "Div"]>:$final_op, + I64ArrayAttr:$subdiv_offsets, + DefaultValuedAttr:$wait_for, + DefaultValuedAttr:$communication_hint + ); + + let results = (outs + TensorOf<[F16, F32, F64, I32, I64]>:$data + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_ComplexOp : TF_Op<"Complex", [NoSideEffect, ResultsBroadcastableShape]> { let summary = "Converts two real numbers to a complex number."; @@ -1408,6 +1583,30 @@ tf.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j] let hasCanonicalizer = 1; } +def TF_ConjugateTransposeOp : TF_Op<"ConjugateTranspose", [NoSideEffect]> { + let summary = [{ +Shuffle dimensions of x according to a permutation and conjugate the result. + }]; + + let description = [{ +The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy: + `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]` + `y[i,j,k,...,s,t,u] == conj(x[perm[i], perm[j], perm[k],...,perm[s], perm[t], perm[u]])` + }]; + + let arguments = (ins + TF_Tensor:$x, + TF_I32OrI64Tensor:$perm + ); + + let results = (outs + TF_Tensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tperm = TF_DerivedOperandTypeAttr<1>; +} + def TF_Conv2DOp : TF_Op<"Conv2D", [NoSideEffect, TF_LayoutSensitiveInterface]> { let summary = [{ Computes a 2-D convolution given 4-D `input` and `filter` tensors. 
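Among the ops added above, ConjugateTranspose combines a dimension permutation with complex conjugation; for a 2-D input and perm = [1, 0] it is the ordinary conjugate transpose of a matrix. A small illustration of that case:

```c++
#include <complex>
#include <vector>

using Matrix = std::vector<std::vector<std::complex<float>>>;

// y[i][j] == conj(x[j][i]) -- the 2-D case of ConjugateTranspose with
// perm = [1, 0] (illustrative only).
Matrix ConjugateTranspose2D(const Matrix& x) {
  const size_t rows = x.size();
  const size_t cols = x.empty() ? 0 : x[0].size();
  Matrix y(cols, std::vector<std::complex<float>>(rows));
  for (size_t i = 0; i < rows; ++i)
    for (size_t j = 0; j < cols; ++j)
      y[j][i] = std::conj(x[i][j]);
  return y;
}
```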
@@ -1682,7 +1881,28 @@ Given an input tensor, this function computes hyperbolic cosine of every TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_CrossReplicaSumOp : TF_Op<"CrossReplicaSum", [AllTypesMatch<["input", "output"]>, NoSideEffect]> { +def TF_CrossOp : TF_Op<"Cross", [NoSideEffect]> { + let summary = "Compute the pairwise cross product."; + + let description = [{ +`a` and `b` must be the same shape; they can either be simple 3-element vectors, +or any shape where the innermost dimension is 3. In the latter case, each pair +of corresponding 3-element vectors is cross-multiplied independently. + }]; + + let arguments = (ins + TF_IntOrFpTensor:$a, + TF_IntOrFpTensor:$b + ); + + let results = (outs + TF_IntOrFpTensor:$product + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_CrossReplicaSumOp : TF_Op<"CrossReplicaSum", [NoSideEffect, TF_AllTypesMatch<["input", "output"]>]> { let summary = "An Op to sum inputs across replicated TPU instances."; let description = [{ @@ -1706,7 +1926,7 @@ and `B, D, F, H` as group 1. Thus we get the outputs: TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_CumsumOp : TF_Op<"Cumsum", [AllTypesMatch<["x", "out"]>, NoSideEffect]> { +def TF_CumsumOp : TF_Op<"Cumsum", [NoSideEffect, TF_AllTypesMatch<["x", "out"]>]> { let summary = "Compute the cumulative sum of the tensor `x` along `axis`."; let description = [{ @@ -2427,6 +2647,76 @@ This operation creates a tensor of `shape` and `dtype`. let hasFolder = 1; } +def TF_EnqueueTPUEmbeddingRaggedTensorBatchOp : TF_Op<"EnqueueTPUEmbeddingRaggedTensorBatch", [SameVariadicOperandSize]> { + let summary = "Eases the porting of code that uses tf.nn.embedding_lookup()."; + + let description = [{ +sample_splits[i], embedding_indices[i] and aggregation_weights[i] correspond +to the ith feature. table_ids[i] indicates which embedding table to look up ith +feature. + +The tensors at corresponding positions in two of the input lists, +embedding_indices and aggregation_weights, must have the same shape, i.e. rank 1 +with dim_size() equal to the total number of lookups into the table described by +the corresponding feature. + }]; + + let arguments = (ins + Variadic:$sample_splits, + Variadic:$embedding_indices, + Variadic:$aggregation_weights, + TF_StrTensor:$mode_override, + + DefaultValuedAttr:$device_ordinal, + DefaultValuedAttr:$combiners, + I64ArrayAttr:$table_ids, + DefaultValuedAttr:$max_sequence_lengths + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T1 = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr T2 = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr T3 = TF_DerivedOperandTypeAttr<2>; + TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>; +} + +def TF_EnqueueTPUEmbeddingSparseTensorBatchOp : TF_Op<"EnqueueTPUEmbeddingSparseTensorBatch", [SameVariadicOperandSize]> { + let summary = [{ +Eases the porting of code that uses tf.nn.embedding_lookup_sparse(). + }]; + + let description = [{ +sample_indices[i], embedding_indices[i] and aggregation_weights[i] correspond +to the ith feature. table_ids[i] indicates which embedding table to look up ith +feature. + +The tensors at corresponding positions in the three input lists (sample_indices, +embedding_indices and aggregation_weights) must have the same shape, i.e. rank 1 +with dim_size() equal to the total number of lookups into the table described by +the corresponding feature. 
+ }]; + + let arguments = (ins + Variadic:$sample_indices, + Variadic:$embedding_indices, + Variadic:$aggregation_weights, + TF_StrTensor:$mode_override, + + DefaultValuedAttr:$device_ordinal, + DefaultValuedAttr:$combiners, + I64ArrayAttr:$table_ids, + DefaultValuedAttr:$max_sequence_lengths + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr T1 = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr T2 = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr T3 = TF_DerivedOperandTypeAttr<2>; + TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>; +} + def TF_EqualOp : TF_Op<"Equal", [Commutative, NoSideEffect]> { let summary = "Returns the truth value of (x == y) element-wise."; @@ -2824,6 +3114,8 @@ fill([2, 3], 9) ==> [[9, 9, 9] return Verify(*this); }]; + let hasFolder = 1; + let builders = [OpBuilder< "OpBuilder &builder, OperationState &result, Value dims, Value value" >]; @@ -3256,8 +3548,8 @@ Gather slices from `params` axis `axis` according to `indices`. let description = [{ `indices` must be an integer tensor of any dimension (usually 0-D or 1-D). -Produces an output tensor with shape `params.shape[:axis] + indices.shape + -params.shape[axis + 1:]` where: +Produces an output tensor with shape `params.shape[:axis] + +indices.shape[batch_dims:] + params.shape[axis + 1:]` where: ```python # Scalar indices (output is rank(params) - 1). @@ -4242,7 +4534,7 @@ cublas. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_MatrixBandPartOp : TF_Op<"MatrixBandPart", [AllTypesMatch<["input", "band"]>, NoSideEffect]> { +def TF_MatrixBandPartOp : TF_Op<"MatrixBandPart", [NoSideEffect, TF_AllTypesMatch<["input", "band"]>]> { let summary = [{ Copy a tensor setting everything outside a central band in each innermost matrix to zero. }]; @@ -6246,6 +6538,8 @@ If `x` and `y` are reals, this will return the floating-point division. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; let hasCanonicalizer = 1; + + let hasFolder = 1; } def TF_ReciprocalOp : TF_Op<"Reciprocal", [NoSideEffect, SameOperandsAndResultType]> { @@ -7349,9 +7643,15 @@ select(condition, t, e) ==> [[1, 2], ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; + + let hasCanonicalizer = 1; + + let verifier = [{ + return Verify(*this); + }]; } -def TF_SelectV2Op : TF_Op<"SelectV2", [NoSideEffect]> { +def TF_SelectV2Op : TF_Op<"SelectV2", [NoSideEffect, ResultsBroadcastableShape]> { let summary = ""; let description = [{ @@ -8224,7 +8524,7 @@ def TF_StackV2Op : TF_Op<"StackV2", []> { ); } -def TF_StopGradientOp : TF_Op<"StopGradient", [NoSideEffect, SameOperandsAndResultType]> { +def TF_StopGradientOp : TF_Op<"StopGradient", [NoSideEffect, TF_AllTypesMatch<["input", "output"]>]> { let summary = "Stops gradient computation."; let description = [{ @@ -10244,6 +10544,33 @@ https://www.tensorflow.org/xla/operation_semantics#gather TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_XlaHostComputeOp : TF_Op<"XlaHostCompute", []> { + let summary = [{ +A pseudo-op to represent host-side computation in an XLA program. 
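The GatherV2 documentation fix above adds the batch_dims term to the result-shape formula. Written out as a plain shape computation (sketch; the example shapes are illustrative):

```c++
#include <cstdint>
#include <vector>

// result.shape = params.shape[:axis] + indices.shape[batch_dims:]
//              + params.shape[axis + 1:]
std::vector<int64_t> GatherV2ResultShape(const std::vector<int64_t>& params,
                                         const std::vector<int64_t>& indices,
                                         int axis, int batch_dims) {
  std::vector<int64_t> out(params.begin(), params.begin() + axis);
  out.insert(out.end(), indices.begin() + batch_dims, indices.end());
  out.insert(out.end(), params.begin() + axis + 1, params.end());
  return out;
}

// e.g. params = [2, 3, 5], indices = [2, 4], axis = 2, batch_dims = 1
//      -> [2, 3, 4]
```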
+ }]; + + let description = [{ + }]; + + let arguments = (ins + Variadic:$inputs, + + StrArrayAttr:$ancestors, + TF_ShapeAttrArray:$shapes, + SymbolRefAttr:$shape_inference_graph, + StrAttr:$key, + DefaultValuedAttr:$cost_estimate_ns, + DefaultValuedAttr:$tpu_core + ); + + let results = (outs + Variadic:$outputs + ); + + TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>; + TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>; +} + def TF_XlaKeyValueSortOp : TF_Op<"XlaKeyValueSort", [NoSideEffect]> { let summary = "Wraps the XLA Sort operator, documented at"; @@ -10292,6 +10619,24 @@ https://www.tensorflow.org/performance/xla/operation_semantics#pad TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_XlaRecvFromHostOp : TF_Op<"XlaRecvFromHost", []> { + let summary = "An op to receive a tensor from the host."; + + let description = [{ + }]; + + let arguments = (ins + TF_ShapeAttr:$shape, + StrAttr:$key + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedResultTypeAttr Toutput = TF_DerivedResultTypeAttr<0>; +} + def TF_XlaReduceOp : TF_Op<"XlaReduce", [NoSideEffect]> { let summary = "Wraps the XLA Reduce operator, documented at"; @@ -10356,6 +10701,23 @@ i=0...N-1. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_XlaSendToHostOp : TF_Op<"XlaSendToHost", []> { + let summary = "An op to send a tensor to the host."; + + let description = [{ + }]; + + let arguments = (ins + TF_Tensor:$input, + + StrAttr:$key + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr Tinput = TF_DerivedOperandTypeAttr<0>; +} + def TF_XlaSvdOp : TF_Op<"XlaSvd", [NoSideEffect]> { let summary = [{ Computes the eigen decomposition of a batch of self-adjoint matrices @@ -10439,6 +10801,27 @@ def TF_ZerosLikeOp : TF_Op<"ZerosLike", [NoSideEffect, SameOperandsAndResultType TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF__HostComputeMlirOp : TF_Op<"_HostComputeMlir", []> { + let summary = "A host-side computation called from a TPU device."; + + let description = [{ + }]; + + let arguments = (ins + Variadic:$inputs, + + StrAttr:$key, + DefaultValuedAttr:$tpu_core + ); + + let results = (outs + Variadic:$outputs + ); + + TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>; + TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>; +} + def TF__RecvTPUEmbeddingActivationsOp : TF_Op<"_RecvTPUEmbeddingActivations", []> { let summary = "An op that receives embeddng activations on the TPU."; @@ -10497,3 +10880,44 @@ used to look up the program in the compilation cache. TF_DerivedResultSizeAttr num_computations = TF_DerivedResultSizeAttr<1>; TF_DerivedOperandSizeAttr NumDynamicShapes = TF_DerivedOperandSizeAttr<0>; } + +def TF__XlaRecvAtHostOp : TF_Op<"_XlaRecvAtHost", []> { + let summary = [{ +A placeholder op to receive values from a running XLA computation. 
+ }]; + + let description = [{ + }]; + + let arguments = (ins + TF_StrTensor:$dynamic_key, + + StrAttr:$key, + I64Attr:$device_ordinal + ); + + let results = (outs + Variadic:$outputs + ); + + TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>; +} + +def TF__XlaSendFromHostOp : TF_Op<"_XlaSendFromHost", []> { + let summary = "A placeholder op to send values to a running XLA computation."; + + let description = [{ + }]; + + let arguments = (ins + Variadic:$inputs, + TF_StrTensor:$dynamic_key, + + StrAttr:$key, + I64Attr:$device_ordinal + ); + + let results = (outs); + + TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>; +} diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td index cb17341cefd..dbd8ab0fae2 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td @@ -23,7 +23,7 @@ limitations under the License. #define TF_OP_BASE include "mlir/IR/OpBase.td" -include "mlir/Interfaces/SideEffects.td" +include "mlir/Interfaces/SideEffectInterfaces.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td" //===----------------------------------------------------------------------===// @@ -70,6 +70,16 @@ class TF_OpIsBroadcastableToRes : And<[ "$_op.getOperand(" # opId # ").getType(), " "$_op.getResult(" # resId # ").getType())">]>; + +class TF_AllTypesMatchPred values> : + CPred<"TF::AreCastCompatible(llvm::makeArrayRef({"# StrJoin.result #"}))">; + +class TF_AllTypesMatch names> : + PredOpTrait< + "all of {" # StrJoin.result # "} have dynamically equal types ", + TF_AllTypesMatchPred< + !foreach(n, names, !subst("$_self", "$" # n, "$_self.getType()"))>>; + //===----------------------------------------------------------------------===// // TensorFlow op definitions //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 3a4e9e5985e..c9d61abe507 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -58,6 +58,7 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/InliningUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/core/platform/logging.h" @@ -110,48 +111,6 @@ static inline bool HasRankAtMost(Value value, int64_t rank) { return !type || type.getRank() <= rank; } -// Returns true if the given pair of TensorFlow types can be cast to one -// another. In other words, a single run-time value is legal for both the types. -// For example, tensor<*xf32> and tensor<3xf32> are cast compatible. -static bool AreCastCompatible(Type a, Type b) { - if (TensorCastOp::areCastCompatible(a, b)) return true; - - // Resource types may optionally contain subtypes information that does not - // match. Check subtypes compatibility when possible, otherwise treat them as - // compatible. 
- auto a_or_element_type = getElementTypeOrSelf(a); - auto b_or_element_type = getElementTypeOrSelf(b); - - auto a_kind = a_or_element_type.getKind(); - auto b_kind = b_or_element_type.getKind(); - - if (a_kind == TensorFlowTypes::RESOURCE && - b_kind == TensorFlowTypes::RESOURCE) { - auto a_resource_type = a_or_element_type.dyn_cast(); - auto b_resource_type = b_or_element_type.dyn_cast(); - bool a_has_subtype = !a_resource_type.getSubtypes().empty(); - bool b_has_subtype = !b_resource_type.getSubtypes().empty(); - - if (!a_has_subtype || !b_has_subtype) return true; - - assert(a_resource_type.getSubtypes().size() <= 1 && - "Resource type must have at most one subtype"); - assert(b_resource_type.getSubtypes().size() <= 1 && - "Resource type must have at most one subtype"); - - return TensorCastOp::areCastCompatible( - a_resource_type.getSubtypes().front(), - b_resource_type.getSubtypes().front()); - } - - // Variant types may optionally contain subtypes information that need not - // match. It is also not possible to compare subtypes for compatibility as - // their interpretation depends on the ops operating on them. So, accept all - // pairs of variant types. - return a_kind == TensorFlowTypes::VARIANT && - b_kind == TensorFlowTypes::VARIANT; -} - static bool IsUnknownDimOrRank(int64_t dim_or_rank) { return dim_or_rank == -1; } @@ -293,6 +252,39 @@ static LogicalResult VerifyTypesCompatibility( return success(); } +// This is a helper for the Select to SelectV2 canonicalization. The `data` rank +// refers to the rank of `t`/`e` (these two inputs have equal rank; this is +// checked in the verifier). +// +// In most cases, the predicate for Select can be used directly as the predicate +// for SelectV2. However, there is one case that varies, which is when the +// predicate is a tensor and the data is multidimensional. In this case, Select +// op semantics dictate that the predicate tensor length must match the size of +// the first data dimension. This varies from normal broadcasting semantics +// (which are used in SelectV2), so we must reshape the tensor in this case to +// be compatible. +static Value ReshapeSelectPredIfNecessary(OpBuilder *builder, Location loc, + Value cond, int data_rank) { + auto cond_tensor = cond.getType().cast(); + // Reshape is only needed in the case that the cond rank is 1 (i.e. it is + // a vector) AND t/e rank is > 1. + if (cond_tensor.getRank() != 1 || data_rank <= 1) { + // No reshape necessary. Leave cond as it is. + return cond; + } + + // This is the case where a reshape is needed. We want to construct the + // shape [x,1,...1], where x is the value in the pred tensor and the + // length of the shape is equal to data_rank. + SmallVector shape(data_rank, 1); + shape[0] = cond_tensor.getShape().front(); + auto new_shape_type = + RankedTensorType::get({data_rank}, builder->getIntegerType(64)); + auto shape_attr = DenseIntElementsAttr::get(new_shape_type, shape); + auto new_shape = builder->create(loc, shape_attr); + return builder->create(loc, cond, new_shape); +} + //===----------------------------------------------------------------------===// // Helper functions detect device capabilities from RuntimeDevices. 
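ReshapeSelectPredIfNecessary above covers the one case where Select and SelectV2 disagree: a rank-1 predicate against data of rank greater than 1. It reshapes the predicate to [len, 1, ..., 1] so that SelectV2's broadcasting reproduces Select's per-row semantics. The shape it builds, as a standalone sketch:

```c++
#include <cstdint>
#include <vector>

// Shape produced by the canonicalization for a rank-1 predicate: the
// predicate length followed by (data_rank - 1) ones.
std::vector<int64_t> SelectPredReshape(int64_t pred_len, int data_rank) {
  std::vector<int64_t> shape(data_rank, 1);
  shape[0] = pred_len;
  return shape;
}

// SelectPredReshape(4, 3) -> {4, 1, 1}; broadcasting [4, 1, 1] against
// [4, d1, d2] data then matches Select's "one predicate per row" behaviour.
```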
//===----------------------------------------------------------------------===// @@ -503,9 +495,10 @@ LogicalResult FoldOperandsPermutation( namespace { // Folder that returns LHS of an Arithmetic Op if the RHS is a constant // known to be Identity (e.g X+0) -template ::value>::type * = nullptr> +template < + typename OpT, + typename std::enable_if::value>::type * = nullptr> OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op, ArrayRef operands) { auto result_op_type = arithmetic_op.getResult().getType(); @@ -520,7 +513,8 @@ OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op, // Mul and Div ops have identity value one while AddV2 and SubOp have identity // value zero. int identity = - (std::is_same::value || std::is_same::value); + (std::is_same::value || std::is_same::value || + std::is_same::value); Type element_ty = lhs_type.getElementType(); Attribute identity_attr; @@ -537,6 +531,12 @@ OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op, return arithmetic_op.x(); } + auto rhs_type = arithmetic_op.y().getType().template cast(); + // TODO(chhe): we could fold and add an identity to force the broadcast. + if (result_op_type != rhs_type) { + return {}; + } + bool is_symmetric = (std::is_same::value || std::is_same::value); if (auto attr = operands[0].dyn_cast_or_null()) { @@ -984,20 +984,17 @@ void ConstOp::build(OpBuilder &builder, OperationState &result, Type type, LogicalResult ConstOp::inferReturnTypes( MLIRContext *context, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { - for (NamedAttribute named_attr : attributes) { - if (named_attr.first.strref() != "value") continue; - auto value = named_attr.second; - if (auto elem_attr = value.dyn_cast()) { - inferredReturnTypes.assign({elem_attr.getType()}); - return success(); - } - return emitOptionalError(location, - "attribute 'value' failed to satisfy constraint: " - "constant vector/tensor"); + auto value = attributes.get("value"); + if (!value) return emitOptionalError(location, "missing attribute 'value'"); + if (auto elem_attr = value.dyn_cast()) { + inferredReturnTypes.assign({elem_attr.getType()}); + return success(); } - return emitOptionalError(location, "missing attribute 'value'"); + return emitOptionalError(location, + "attribute 'value' failed to satisfy constraint: " + "constant vector/tensor"); } //===----------------------------------------------------------------------===// @@ -1300,8 +1297,8 @@ static LogicalResult Verify(DataFormatVecPermuteOp op) { if (rank == 1) { int64_t dim0 = input_ty.getDimSize(0); - if (dim0 != ShapedType::kDynamicSize && dim0 != 4) - return op.emitOpError("requires 1D input of size 4"); + if (dim0 != ShapedType::kDynamicSize && dim0 != 4 && dim0 != 2) + return op.emitOpError("requires 1D input of size 4 or size 2"); } if (rank == 2) { @@ -1416,7 +1413,7 @@ static LogicalResult Verify(DynamicStitchOp op) { auto expected_out_ty = RankedTensorType::get(expected_shape, out_ty.getElementType()); - if (!AreCastCompatible(out_ty, expected_out_ty)) { + if (!AreCastCompatible({out_ty, expected_out_ty})) { return op.emitOpError() << "has invalid output type; should be " "compatible with inferred type " << expected_out_ty; @@ -1650,7 +1647,7 @@ static ShapedType InferFillOpType(Value dims, Value value) { llvm::SmallVector shape; shape.reserve(dims_attr.getNumElements()); - for (const APInt &dim : dims_attr.getValues()) { + for (const APInt dim : 
dims_attr.getValues()) { shape.push_back(dim.getSExtValue()); } return RankedTensorType::get(shape, etype); @@ -1661,6 +1658,35 @@ void FillOp::build(OpBuilder &builder, OperationState &result, Value dims, FillOp::build(builder, result, InferFillOpType(dims, value), dims, value); } +OpFoldResult FillOp::fold(ArrayRef operands) { + assert(operands.size() == 2 && "fill op has two operand"); + + auto type = getType().cast(); + // DenseElementsAttr that is used in this folder only supports int and float + // types. + // TODO(hinsu): Handle complex types once there is a attribute kind for + // complex. + if (!type.getElementType().isIntOrFloat()) return {}; + + auto value = operands[1].dyn_cast_or_null(); + if (!value) return {}; + + if (type.hasStaticShape()) + return DenseElementsAttr::get(type, value.getValue({})); + + auto dims = operands[0].dyn_cast_or_null(); + if (!dims) return {}; + + llvm::SmallVector shape; + shape.reserve(dims.getNumElements()); + for (const APInt dim : dims.getValues()) { + shape.push_back(dim.getSExtValue()); + } + type = RankedTensorType::get(shape, type.getElementType()); + + return DenseElementsAttr::get(type, value.getValue({})); +} + //===----------------------------------------------------------------------===// // FusedBatchNormGradOp //===----------------------------------------------------------------------===// @@ -1795,75 +1821,129 @@ static LogicalResult Verify(GatherV2Op op) { static LogicalResult Verify(IfOp op) { auto module = op.getParentOfType(); - auto thenFn = module.lookupSymbol(op.then_branch()); - if (!thenFn) + auto then_fn = module.lookupSymbol(op.then_branch()); + if (!then_fn) return op.emitOpError("then_branch refers to an undefined function : ") << op.then_branch(); - auto elseFn = module.lookupSymbol(op.else_branch()); - if (!elseFn) + auto else_fn = module.lookupSymbol(op.else_branch()); + if (!else_fn) return op.emitOpError("else_branch refers to an undefined function : ") << op.else_branch(); - auto thenFuncType = thenFn.getType(); - auto elseFuncType = elseFn.getType(); + auto then_fn_type = then_fn.getType(); + auto else_fn_type = else_fn.getType(); // Non-conditional operands starting with the second operand are passed to // branches and should be pair-wise compatible with branches' inputs. 
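A rough standalone sketch of what that pair-wise compatibility check amounts to (plain C++ with strings standing in for MLIR types; CastCompatible stands in for AreCastCompatible, and the function name is illustrative only):

#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Returns an empty string on success, or a diagnostic describing the first
// mismatch, mirroring the verifier structure below.
std::string CheckIfBranchInputs(
    const std::vector<std::string>& op_operand_types,  // without the cond operand
    const std::vector<std::string>& then_input_types,
    const std::vector<std::string>& else_input_types,
    const std::function<bool(const std::string&, const std::string&)>&
        CastCompatible) {
  if (then_input_types.size() != op_operand_types.size() ||
      else_input_types.size() != op_operand_types.size())
    return "branches should have " + std::to_string(op_operand_types.size()) +
           " inputs";
  for (size_t i = 0; i < op_operand_types.size(); ++i) {
    if (!CastCompatible(op_operand_types[i], then_input_types[i]))
      return "then branch input incompatible with operand at index " +
             std::to_string(i);
    if (!CastCompatible(op_operand_types[i], else_input_types[i]))
      return "else branch input incompatible with operand at index " +
             std::to_string(i);
    // If the two branch inputs are incompatible with each other, no tensor
    // could feed both functions, so the op itself is invalid.
    if (!CastCompatible(then_input_types[i], else_input_types[i]))
      return "branch inputs incompatible with each other at index " +
             std::to_string(i);
  }
  return "";
}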
- unsigned expectedNumInputs = op.getNumOperands() - 1; - if (thenFuncType.getNumInputs() != expectedNumInputs || - elseFuncType.getNumInputs() != expectedNumInputs) - return op.emitError("branches should have " + Twine(expectedNumInputs) + + unsigned expected_num_inputs = op.getNumOperands() - 1; + if (then_fn_type.getNumInputs() != expected_num_inputs || + else_fn_type.getNumInputs() != expected_num_inputs) + return op.emitError("branches should have " + Twine(expected_num_inputs) + " inputs"); - for (unsigned i = 0; i < expectedNumInputs; ++i) { - auto operandType = op.getOperand(i + 1).getType().cast(); - auto thenInputType = thenFuncType.getInput(i).cast(); - if (!AreCastCompatible(operandType, thenInputType)) + for (unsigned i = 0; i < expected_num_inputs; ++i) { + auto operand_type = op.getOperand(i + 1).getType().cast(); + auto then_input_type = then_fn_type.getInput(i).cast(); + if (!AreCastCompatible({operand_type, then_input_type})) return op.emitError( llvm::formatv("then branch input type {0} is incompatible with " "operand type {1} at index {2}", - thenInputType, operandType, i)); + then_input_type, operand_type, i)); - auto elseInputType = elseFuncType.getInput(i).cast(); - if (!AreCastCompatible(operandType, elseInputType)) + auto else_input_type = else_fn_type.getInput(i).cast(); + if (!AreCastCompatible({operand_type, else_input_type})) return op.emitError( llvm::formatv("else branch input type {0} is incompatible with " "operand type {1} at index {2}", - elseInputType, operandType, i)); + else_input_type, operand_type, i)); // If branches have incompatible input types that means that no tensor can // serve as input to both the functions. Hence, the op is invalid. - if (!AreCastCompatible(thenInputType, elseInputType)) + if (!AreCastCompatible({then_input_type, else_input_type})) return op.emitError(llvm::formatv( "branches inputs have incompatible types {0} and {1} at index {2}", - thenInputType, elseInputType, i)); + then_input_type, else_input_type, i)); } // Branches' results should be pair-wise compatible with the op results. 
- unsigned expectedNumResults = op.getNumResults(); - if (thenFuncType.getNumResults() != expectedNumResults || - elseFuncType.getNumResults() != expectedNumResults) - return op.emitError("branches should have " + Twine(expectedNumResults) + + unsigned expected_num_results = op.getNumResults(); + if (then_fn_type.getNumResults() != expected_num_results || + else_fn_type.getNumResults() != expected_num_results) + return op.emitError("branches should have " + Twine(expected_num_results) + " results"); - for (unsigned i = 0; i < expectedNumResults; ++i) { - auto resultType = op.getResult(i).getType().cast(); - auto thenResultType = thenFuncType.getResult(i).cast(); - if (!AreCastCompatible(thenResultType, resultType)) + for (unsigned i = 0; i < expected_num_results; ++i) { + auto result_type = op.getResult(i).getType().cast(); + auto then_result_type = then_fn_type.getResult(i).cast(); + if (!AreCastCompatible({then_result_type, result_type})) return op.emitError( llvm::formatv("then branch result type {0} is incompatible with op " "result type {1} at index {2}", - thenResultType, resultType, i)); + then_result_type, result_type, i)); - auto elseResultType = elseFuncType.getResult(i).cast(); - if (!AreCastCompatible(elseResultType, resultType)) + auto else_result_type = else_fn_type.getResult(i).cast(); + if (!AreCastCompatible({else_result_type, result_type})) return op.emitError( llvm::formatv("else branch result type {0} is incompatible with op " "result type {1} at index {2}", - elseResultType, resultType, i)); + else_result_type, result_type, i)); } return success(); } +//===----------------------------------------------------------------------===// +// YieldOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(YieldOp op) { + auto parent = op.getParentOp(); + // A YieldOp should be contained within an IfRegion op + // (and WhileRegion in future) + if (!isa(parent)) + op.emitError() << " expects parent op " + << "'" << IfRegionOp::getOperationName() << "' but got '" + << parent->getName().getStringRef() << "'"; + return success(); +} + +//===----------------------------------------------------------------------===// +// IfRegionOp +//===----------------------------------------------------------------------===// + +LogicalResult VerifyRegionResults(Operation *op, Region ®ion, + StringRef region_name) { + auto op_name = op->getName().getStringRef(); + // verify that op outputs match yield inputs + YieldOp yield = cast(region.front().getTerminator()); + unsigned expected_num_results = op->getNumResults(); + if (yield.getNumOperands() != expected_num_results) + return op->emitError(region_name + " region should have " + + Twine(expected_num_results) + " results"); + + for (int idx : llvm::seq(0, expected_num_results)) { + auto op_result_type = op->getResult(idx).getType().cast(); + auto region_result_type = + yield.getOperand(idx).getType().cast(); + if (!AreCastCompatible({region_result_type, op_result_type})) + return op->emitError(llvm::formatv( + "{0} result type {1} is incompatible with {2} " + "result type {3} at index {4}", + region_name, region_result_type, op_name, op_result_type, idx)); + } + return success(); +} + +static LogicalResult Verify(IfRegionOp op) { + if (failed(VerifyRegionResults(op, op.then_branch(), "then"))) + return failure(); + if (failed(VerifyRegionResults(op, op.else_branch(), "else"))) + return failure(); + if (op.then_branch().front().getNumArguments() != 0) + return op.emitOpError() << "then 
region cannot have any arguments"; + if (op.else_branch().front().getNumArguments() != 0) + return op.emitOpError() << "else region cannot have any arguments"; + return success(); +} + //===----------------------------------------------------------------------===// // InvertOp //===----------------------------------------------------------------------===// @@ -2429,6 +2509,10 @@ void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results, results.insert(context); } +OpFoldResult RealDivOp::fold(ArrayRef operands) { + return IdentityArithmeticOpFolder(*this, operands); +} + //===----------------------------------------------------------------------===// // ReshapeOp //===----------------------------------------------------------------------===// @@ -2560,6 +2644,81 @@ void ReshapeOp::build(OpBuilder &builder, OperationState &result, Value tensor, return unranked(); } +//===----------------------------------------------------------------------===// +// SelectOp +//===----------------------------------------------------------------------===// + +void SelectOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +// Verifies a few extra requirements on SelectOp: +// (1) `then` and `else` must have same shape +// (2) At least one of the following must be true: +// (a) `cond` has the same rank as `then` and `else` +// (b) `cond` is a scalar +// (c) `cond` is a vector AND `then` and `else` are non-scalar with their +// first dimension equal to `cond`. +static LogicalResult Verify(SelectOp op) { + auto then_tensor = op.t().getType().cast(); + auto else_tensor = op.e().getType().cast(); + // Check (1). + if (!AreCastCompatible({then_tensor, else_tensor})) + return op.emitOpError() << "requires t and e have compatible shapes"; + + // Get data rank (if exists). + int data_rank; + // If data is unranked or data_rank is 0, this will remain -2. Otherwise + // refers to first dimension of then and/or else. + int data_first_dim = -2; + bool then_has_rank = then_tensor.hasRank(); + bool else_has_rank = else_tensor.hasRank(); + if (then_has_rank && else_has_rank) { + data_rank = then_tensor.getRank(); + if (then_tensor.getRank() > 0) + data_first_dim = then_tensor.getShape().front(); + if (else_tensor.getRank() > 0) + data_first_dim = std::max( + static_cast(else_tensor.getShape().front()), data_first_dim); + } else if (then_has_rank) { + data_rank = then_tensor.getRank(); + if (then_tensor.getRank() > 0) + data_first_dim = then_tensor.getShape().front(); + } else if (else_has_rank) { + data_rank = else_tensor.getRank(); + if (else_tensor.getRank() > 0) + data_first_dim = else_tensor.getShape().front(); + } else { + // Neither has a rank. + return success(); + } + + auto cond_tensor = op.condition().getType().dyn_cast(); + if (!cond_tensor) return success(); + auto cond_rank = cond_tensor.getRank(); + // Check (2a) and (2b). + if (cond_rank == 0 || cond_rank == data_rank) return success(); + // Check (2c). + if (cond_rank == 1) { + auto cond_shape = cond_tensor.getShape().front(); + if (data_rank == 0) { + return op.emitOpError() + << "requires that t and e are nonscalar when pred is a vector"; + } + // We know `data` tensor has a rank of at least 1. + if (data_first_dim != -1 && cond_shape != -1 && + data_first_dim != cond_shape) { + return op.emitOpError() << "requires that, when pred is a vector, the " + "shape matches the first dimension of t and e"; + } + return success(); + } + // None of (2a,b,c) were true; fail. 
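Taken together, checks (2a) through (2c) above reduce to a small decision on the predicate's rank; a minimal standalone sketch under simplified assumptions (plain C++, -1 standing in for a dynamic size, SelectPredShapeOk a hypothetical name):

#include <cstdint>

// Returns true when a predicate of rank `cond_rank` (with leading size
// `cond_len` if it is a vector) is acceptable for data of rank `data_rank`
// whose leading dimension is `data_first_dim`. -1 means dynamic/unknown.
bool SelectPredShapeOk(int64_t cond_rank, int64_t cond_len,
                       int64_t data_rank, int64_t data_first_dim) {
  if (cond_rank == 0) return true;           // (2b) scalar predicate
  if (cond_rank == data_rank) return true;   // (2a) same rank as t and e
  if (cond_rank != 1) return false;          // otherwise pred must be a vector
  if (data_rank == 0) return false;          // (2c) t and e must be non-scalar
  // (2c) the vector length must match the first data dimension, unless either
  // side is dynamic and therefore cannot be verified at this point.
  return cond_len == -1 || data_first_dim == -1 || cond_len == data_first_dim;
}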
+ return op.emitOpError() << "requires that pred is a scalar OR has the same " + "rank as t and e OR is a vector"; +} + //===----------------------------------------------------------------------===// // SelectV2Op //===----------------------------------------------------------------------===// @@ -2619,9 +2778,12 @@ LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type, << variadic_idx_str << " to match rank of operand" << variadic_idx_str; } else if (result_ranked_type.hasStaticShape()) { - // The operand is an unranked tensor, verify that the result is dynamic. - return op->emitOpError("requires dynamic shape result") - << variadic_idx_str << " for unranked operand" << variadic_idx_str; + // The operand is an unranked tensor, print a warning if the result + // is static. + // Note: We do not handle this situation as an error, this would be too + // restrictive due to incompleteness of shape inference at this point. + op->emitWarning("has static shape result") + << variadic_idx_str << " for unranked operand" << variadic_idx_str; } Type element_type = result_ranked_type.getElementType(); @@ -3572,12 +3734,20 @@ OpFoldResult FoldIdentityTranspose(TransposeOp op) { if (!const_perm) return {}; auto const_value = const_perm.value(); - const auto &elements = const_value.getValues(); + const auto elements = const_value.getValues(); for (auto it : llvm::enumerate(elements)) { if (it.index() != it.value()) return {}; } + // TODO(jpienaar): Remove if/when we handle this more generally. + if (op.getType() != op.x().getType()) { + // If the types don't match then only fold if all the operands are in the TF + // dialect. + for (auto user : op.getOperation()->getUsers()) + if (user->getDialect() != op.getDialect()) return {}; + } + return op.x(); } @@ -3721,36 +3891,37 @@ OpFoldResult VariableShapeOp::fold(ArrayRef operands) { static LogicalResult Verify(WhileOp op) { auto module = op.getParentOfType(); - auto condFn = module.lookupSymbol(op.cond()); - auto bodyFn = module.lookupSymbol(op.body()); - if (!condFn) { + auto cond_fn = module.lookupSymbol(op.cond()); + auto body_fn = module.lookupSymbol(op.body()); + if (!cond_fn) { return op.emitOpError("cond refers to an undefined function : ") << op.cond(); } - if (!bodyFn) { + if (!body_fn) { return op.emitOpError("body refers to an undefined function : ") << op.body(); } - auto condFuncType = condFn.getType(); - auto bodyFuncType = bodyFn.getType(); + auto cond_fn_type = cond_fn.getType(); + auto body_fn_type = body_fn.getType(); // Verify that the cond function has exactly one result. - if (condFuncType.getNumResults() != 1) + if (cond_fn_type.getNumResults() != 1) return op.emitOpError("requires cond function to have exactly one result"); SmallVector operands(op.getOperandTypes()); // Collect all the type lists for the op so that different pairs of type lists // can be compared for the compatibility. 
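The loop further below compares these lists pair-wise, starting the inner index at max(2, i + 1) so that exactly one pair, operands vs. body-function results, is skipped; a small standalone illustration of which pairs end up being checked (plain C++, names illustrative only):

#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
  // Index order matches the array in the verifier: 0 operand, 1 body result,
  // 2 result, 3 cond input, 4 body input.
  const char* kNames[] = {"operand", "body result", "result", "cond input",
                          "body input"};
  constexpr int kNumTypeLists = 5;
  std::vector<std::pair<int, int>> pairs;
  for (int i = 0; i < kNumTypeLists; ++i)
    // Starting j at max(2, i + 1) skips exactly one pair, (0, 1): While op
    // operands need not be compatible with the body function's results.
    for (int j = std::max(2, i + 1); j < kNumTypeLists; ++j)
      pairs.emplace_back(i, j);
  for (const auto& p : pairs)
    std::printf("check %s vs %s\n", kNames[p.first], kNames[p.second]);
  // Prints 9 of the 10 possible pairs; only (operand, body result) is omitted.
}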
- int numTypeLists = 5; - std::pair> typeLists[] = { - {"operand", operands}, - {"body function result", bodyFuncType.getResults()}, - {"result", op.getResultTypes()}, - {"cond function input", condFuncType.getInputs()}, - {"body function input", bodyFuncType.getInputs()}, - }; + constexpr int kNumTypeLists = 5; + const std::array>, kNumTypeLists> + type_lists = {{ + {"operand", operands}, + {"body function result", body_fn_type.getResults()}, + {"result", op.getResultTypes()}, + {"cond function input", cond_fn_type.getInputs()}, + {"body function input", body_fn_type.getInputs()}, + }}; // A pair of type lists should be cast compatible with each other if one is // converted to the another for a function call or assignment or there is a @@ -3774,28 +3945,28 @@ static LogicalResult Verify(WhileOp op) { // never converted from one to the another nor there is a common source // tensors. Compatibility requirement is not transitive. - for (int i = 0; i < numTypeLists; ++i) { + for (int i = 0; i < kNumTypeLists; ++i) { // Skip the first pair as the While op operands and body function results // does not need to be compatible with each other. - for (int j = std::max(2, i + 1); j < numTypeLists; ++j) { - auto &a = typeLists[i]; - auto &b = typeLists[j]; + for (int j = std::max(2, i + 1); j < kNumTypeLists; ++j) { + auto &a = type_lists[i]; + auto &b = type_lists[j]; - int aSize = a.second.size(); - if (aSize != b.second.size()) + int a_size = a.second.size(); + if (a_size != b.second.size()) return op.emitOpError( llvm::formatv("requires the number of {0}s to be equal to the " "number of {1}s. Found {2} and {3}, respectively", - a.first, b.first, aSize, b.second.size())); + a.first, b.first, a_size, b.second.size())); - for (int idx = 0; idx < aSize; ++idx) { - auto aType = a.second[idx]; - auto bType = b.second[idx]; + for (int idx = 0; idx < a_size; ++idx) { + auto a_type = a.second[idx]; + auto b_type = b.second[idx]; - if (!AreCastCompatible(aType, bType)) + if (!AreCastCompatible({a_type, b_type})) return op.emitError(llvm::formatv( "{0} type {1} is incompatible with {2} type {3} at index {4}", - a.first, aType, b.first, bType, idx)); + a.first, a_type, b.first, b_type, idx)); } } } @@ -3877,7 +4048,7 @@ TensorFlowDialect::TensorFlowDialect(MLIRContext *context) #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def" >(); addInterfaces(); - addAttributes(); + addAttributes(); // Support unknown operations because not all TensorFlow operations are // registered. @@ -3932,6 +4103,49 @@ void PrintShapeAttr(ShapeAttr attr, DialectAsmPrinter &os) { // NOLINT os << ">"; } +// Parses a #tf.func attribute of the following format: +// +// #tf.func<@symbol, {attr = "value"}> +// +// where the first element is a SymbolRefAttr and the second element is a +// DictionaryAttr. 
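The ParseFuncAttr implementation that follows walks this spec piece by piece, asking mlir::parseAttribute how many characters each sub-attribute consumed. As a rough standalone approximation of the same parser shape (std::string_view only, no MLIR involved; ParseFuncSpec is a hypothetical name, and the naive splitting here is deliberately looser than the real parser):

#include <optional>
#include <string>
#include <string_view>
#include <utility>

// Consumes `prefix` from the front of `s` if present.
static bool ConsumeFront(std::string_view& s, std::string_view prefix) {
  if (s.substr(0, prefix.size()) != prefix) return false;
  s.remove_prefix(prefix.size());
  return true;
}

// Splits a spec of the form func<@symbol, {...}> into its symbol and
// dictionary parts. The real parser instead delegates both pieces to
// mlir::parseAttribute and verifies they are a SymbolRefAttr and a
// DictionaryAttr, so it also rejects inputs this sketch would let through
// (e.g. trailing elements after the dictionary).
std::optional<std::pair<std::string, std::string>> ParseFuncSpec(
    std::string_view spec) {
  if (!ConsumeFront(spec, "func<")) return std::nullopt;
  size_t comma = spec.find(", ");
  if (comma == std::string_view::npos) return std::nullopt;
  std::string symbol(spec.substr(0, comma));          // e.g. "@symbol"
  spec.remove_prefix(comma + 2);
  if (spec.empty() || spec.back() != '>') return std::nullopt;
  std::string dict(spec.substr(0, spec.size() - 1));  // e.g. "{attr = ...}"
  if (symbol.empty() || symbol.front() != '@') return std::nullopt;  // symbol ref
  if (dict.empty() || dict.front() != '{') return std::nullopt;      // dictionary
  return std::make_pair(symbol, dict);
}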
+FuncAttr ParseFuncAttr(MLIRContext *context, StringRef spec, Location loc) { + auto emit_error = [&, spec]() { + emitError(loc, "invalid TensorFlow func attribute: ") << spec; + return nullptr; + }; + + if (!spec.consume_front("func<")) return emit_error(); + + size_t func_name_num_read = 0; + Attribute func_name_attr = + mlir::parseAttribute(spec, context, func_name_num_read); + if (!func_name_attr || !func_name_attr.isa()) + return emit_error(); + spec = spec.drop_front(func_name_num_read); + + if (!spec.consume_front(", ")) return emit_error(); + + size_t func_attrs_num_read = 0; + Attribute func_attrs_attr = + mlir::parseAttribute(spec, context, func_attrs_num_read); + if (!func_attrs_attr || !func_attrs_attr.isa()) + return emit_error(); + spec = spec.drop_front(func_attrs_num_read); + + if (!spec.consume_front(">")) return emit_error(); + + return mlir::TF::FuncAttr::get(context, func_name_attr.cast(), + func_attrs_attr.cast()); +} + +// Prints a #tf.func attribute of the following format: +// +// #tf.func<@symbol, {attr = "value"}> +void PrintFuncAttr(FuncAttr attr, DialectAsmPrinter &os) { + os << "func<" << attr.GetName() << ", " << attr.GetAttrs() << ">"; +} + } // namespace Attribute TensorFlowDialect::parseAttribute(DialectAsmParser &parser, @@ -3941,6 +4155,8 @@ Attribute TensorFlowDialect::parseAttribute(DialectAsmParser &parser, if (spec.startswith("shape")) return ParseShapeAttr(getContext(), spec, loc); + if (spec.startswith("func")) return ParseFuncAttr(getContext(), spec, loc); + return (emitError(loc, "unknown TensorFlow attribute: " + spec), nullptr); } @@ -3950,6 +4166,9 @@ void TensorFlowDialect::printAttribute(Attribute attr, case AttrKind::SHAPE: PrintShapeAttr(attr.cast(), os); break; + case AttrKind::FUNC: + PrintFuncAttr(attr.cast(), os); + break; default: llvm_unreachable("unexpected tensorflow attribute kind"); } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h index 979f506b3b1..88307267ab4 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h @@ -31,7 +31,7 @@ limitations under the License. #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project #include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project #include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project -#include "mlir/Interfaces/SideEffects.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index 744d1ac5b71..5f3a1a5be35 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -99,6 +99,30 @@ def TF_ConstOp : TF_Op<"Const", [ConstantLike, NoSideEffect, }]; } +def TF_CollectivePermuteOp : TF_Op<"CollectivePermute", []> { + let summary = "An Op to permute tensors across replicated TPU instances."; + + let description = [{ +Each instance supplies its own input. + +For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing +source_target_pairs=`[[0,1],[1,2],[2,3],[3,0]]` gets the outputs: +`[D, A, B, C]`. 
+ }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, + I32Tensor:$source_target_pairs + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + + def TF_DataFormatVecPermuteOp : TF_Op<"DataFormatVecPermute", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Permute input tensor from `src_format` to `dst_format`"; @@ -188,7 +212,6 @@ else_branch: A function that takes 'inputs' and returns a list of FlatSymbolRefAttr:$then_branch, FlatSymbolRefAttr:$else_branch, - DefaultValuedAttr:$output_shapes, // Used to map StatelessIf and If op defined in TensorFlow to a common op. BoolAttr:$is_stateless @@ -207,6 +230,66 @@ else_branch: A function that takes 'inputs' and returns a list of }]; } +def TF_YieldOp : TF_Op<"Yield", [Terminator]> { + let summary = "Yield operation"; + let description = [{ + The "yield" operation represents a return operation within the conditional + and body of structured control flow (e.g., if and while). The operation + takes a variable number of operands and produces no results. The number and + types of inputs must match the signature of the operation that contains the + region. + }]; + + let arguments = (ins Variadic:$operands); + + let verifier = [{ + return Verify(*this); + }]; +} + +def TF_IfRegionOp : TF_Op<"IfRegion", + [SingleBlockImplicitTerminator<"YieldOp">]> { + let summary = "output = cond ? then_branch output : else_branch output"; + + let description = [{ +"output = cond ? then_branch output : else_branch output" + +cond: A Tensor. If the tensor is a scalar of non-boolean type, the + scalar is converted to a boolean according to the + following rule: if the scalar is a numerical value, non-zero means + True and zero means False; if the scalar is a string, non-empty + means True and empty means False. If the tensor is not a scalar, + being empty means False and being non-empty means True. +then_branch: A region that computes the outputs of the op if cond = true. + It returns a list of tensors using tf.yield (as the terminator). The + types of these returned tensors is same as that of the else_branch +else_branch: A region that computes the outputs of the op if cond = false. + It returns a list of tensors using tf.yield (as the terminator). The + types of these returned tensors is same as that of the then_branch + }]; + + let arguments = (ins + TF_Tensor:$cond, + + // Used to map StatelessIf and If op defined in TensorFlow to a common op. 
+ BoolAttr:$is_stateless + ); + + let results = (outs + Variadic:$output + ); + + TF_DerivedOperandTypeAttr Tcond = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>; + TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>; + + let regions = (region SizedRegion<1>:$then_branch, SizedRegion<1>:$else_branch); + + let verifier = [{ + return Verify(*this); + }]; +} + def TF_MeanOp : TF_Op<"Mean", [NoSideEffect, TF_FoldOperandsTransposeInterface]> { let summary = "Computes the mean of elements across dimensions of a tensor."; @@ -905,5 +988,29 @@ def TF_TensorSliceDatasetOp : TF_Op<"TensorSliceDataset", []> { TF_DerivedOperandTypeListAttr Toutput_types = TF_DerivedOperandTypeListAttr<0>; } +// TODO(b/156507832): Move tf.InplaceUpdate to tf_generated_ops.td once +// autogenerated op def matches. +def TF_InplaceUpdateOp : TF_Op<"InplaceUpdate", [NoSideEffect]> { + let summary = "Updates specified rows 'i' with values 'v'."; + + let description = [{ +Computes `x[i, :] = v; return x`. + +Originally this function is mutative however for compilation we make this +operation create / operate on a copy of `x`. + }]; + + let arguments = (ins + TF_Tensor:$x, + I32Tensor:$i, + TF_Tensor:$v + ); + + let results = (outs + TF_Tensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} #endif // TF_OPS diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc index 6c3cd7fac92..994378ea1cf 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc @@ -28,6 +28,134 @@ llvm::Optional> GetShape(mlir::Value value) { if (shaped_type.hasRank()) return shaped_type.getShape(); return llvm::None; } + +// Merges cast compatible shapes and returns a more refined shape. The two +// shapes are cast compatible if they have the same rank and at each dimension, +// either both have same size or one of them is dynamic. Returns false if the +// given shapes are not cast compatible. The refined shape is same or more +// precise than the two input shapes. +bool GetCastCompatibleShape(llvm::ArrayRef a_shape, + llvm::ArrayRef b_shape, + llvm::SmallVectorImpl* refined_shape) { + if (a_shape.size() != b_shape.size()) return false; + int64_t rank = a_shape.size(); + refined_shape->reserve(rank); + for (auto dims : llvm::zip(a_shape, b_shape)) { + int64_t dim1 = std::get<0>(dims); + int64_t dim2 = std::get<1>(dims); + + if (mlir::ShapedType::isDynamic(dim1)) { + refined_shape->push_back(dim2); + continue; + } + if (mlir::ShapedType::isDynamic(dim2)) { + refined_shape->push_back(dim1); + continue; + } + if (dim1 == dim2) { + refined_shape->push_back(dim1); + continue; + } + return false; + } + return true; +} + +// Given two types `a` and `b`, returns a refined type which is cast compatible +// with both `a` and `b` and is equal to or more precise than both of them. It +// returns empty Type if the input types are not cast compatible. +// +// The two types are considered cast compatible if they have dynamically equal +// shapes and element type. For element types that do not have subtypes, they +// must be equal. However for TensorFlow types such as Resource and Variant, +// that also have subtypes, we recursively check for subtype compatibilty for +// Resource types and assume all variant types are cast compatible. If either +// one of `a` or `b` have empty subtypes, they are considered cast compatible. 
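GetCastCompatibleShape above supplies the per-dimension half of this refinement; a minimal standalone sketch of the same merge (plain C++, with -1 standing in for ShapedType::kDynamicSize and RefineShapes a hypothetical name):

#include <cstdint>
#include <optional>
#include <vector>

constexpr int64_t kDynamic = -1;  // stand-in for ShapedType::kDynamicSize

// Merges two cast-compatible shapes into the most refined one, or returns
// nullopt when they disagree (different rank, or conflicting static sizes).
std::optional<std::vector<int64_t>> RefineShapes(
    const std::vector<int64_t>& a, const std::vector<int64_t>& b) {
  if (a.size() != b.size()) return std::nullopt;
  std::vector<int64_t> refined;
  refined.reserve(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    if (a[i] == kDynamic) { refined.push_back(b[i]); continue; }
    if (b[i] == kDynamic) { refined.push_back(a[i]); continue; }
    if (a[i] != b[i]) return std::nullopt;  // conflicting static dimensions
    refined.push_back(a[i]);
  }
  return refined;
}

// For the example in the comment that follows: merging {2, -1, -1} with
// {-1, 4, -1} yields {2, 4, -1}, i.e. tensor<2x4x?xf32>.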
+// +// The returned type is same or more precise than the input types. For example, +// if `a` and `b` are cast compatible types tensor<2x?x?xf32> and +// tensor respectively, the returned type is tensor<2x4x?xf32>. +// +// Provides option to ignore ref types on 'a'. This is useful for TF ops that +// might allow operands to either be same as result type or be a ref type +// corresponding to it. +mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b, + bool may_ignore_ref_type_a) { + // Fast path if everything is equal. + if (a == b) return b; + + auto a_tt = a.dyn_cast(); + auto b_tt = b.dyn_cast(); + + // If only one of a or b is a tensor type, they are incompatible. + if (static_cast(a_tt) ^ static_cast(b_tt)) return nullptr; + + // For non-tensor types, we do not need to worry about shape and can return + // early. + if (!a_tt && !b_tt) { + // Remove ref types. + if (may_ignore_ref_type_a) { + if (auto ref_type = a.dyn_cast()) { + a = ref_type.RemoveRef(); + if (a == b) return a; + } + } + if (a.getKind() != b.getKind()) return nullptr; + + // If either is not a type that contain subtypes then the types are not cast + // compatible. + auto a_wst = a.dyn_cast(); + auto b_wst = b.dyn_cast(); + if (!a_wst || !b_wst) return nullptr; + + // For Variant types we are more permissive right now and accept all pairs + // of Variant types. If we are more constrainted and check compatibility of + // subtypes, we might reject valid graphs. + // TODO(prakalps): Variant doesn't have a subtype, we assign it + // one, so we should only assign it one when we know the subtype. Then we + // can be more constrained and check subtypes for cast compatibility as + // well. + if (a.isa()) return a; + + // For Resource types, we recursively check the subtypes for cast + // compatibility, if possible. Otherwise treat them as compatible. + auto a_wst_st = a_wst.GetSubtypes(); + auto b_wst_st = b_wst.GetSubtypes(); + if (a_wst_st.empty() || b_wst_st.empty()) return a; + if (a_wst_st.size() != b_wst_st.size()) return nullptr; + llvm::SmallVector refined_subtypes; + for (auto subtypes : llvm::zip(a_wst_st, b_wst_st)) { + mlir::Type refined_st = + GetCastCompatibleType(std::get<0>(subtypes), std::get<1>(subtypes), + /*may_ignore_ref_type_a=*/false); + if (!refined_st) return nullptr; + refined_subtypes.push_back(refined_st.cast()); + } + + return mlir::TF::ResourceType::get(refined_subtypes, a.getContext()); + } + + // For tensor types, check compatibility of both element type and shape. + mlir::Type refined_element_ty = GetCastCompatibleType( + a_tt.getElementType(), b_tt.getElementType(), may_ignore_ref_type_a); + if (!refined_element_ty) return nullptr; + + if (!a_tt.hasRank() && !b_tt.hasRank()) { + return mlir::UnrankedTensorType::get(refined_element_ty); + } + if (!a_tt.hasRank()) { + return mlir::RankedTensorType::get(b_tt.getShape(), refined_element_ty); + } + if (!b_tt.hasRank()) { + return mlir::RankedTensorType::get(a_tt.getShape(), refined_element_ty); + } + + llvm::SmallVector refined_shape; + if (!GetCastCompatibleShape(a_tt.getShape(), b_tt.getShape(), &refined_shape)) + return nullptr; + + return mlir::RankedTensorType::get(refined_shape, refined_element_ty); +} } // namespace namespace mlir { @@ -224,47 +352,41 @@ bool BroadcastCompatible(ArrayRef lhs, ArrayRef rhs) { bool HasCompatibleElementTypes(Type lhs, Type rhs, bool may_ignore_ref_type_lhs) { - // Fast path if everything is equal. 
- if (lhs == rhs) return true; + return GetCastCompatibleType(lhs, rhs, may_ignore_ref_type_lhs) != nullptr; +} - // In TF all values are tensors. - auto lhs_tt = lhs.cast(); - auto rhs_tt = rhs.cast(); - - // Verify matching element types. These should be identical dynamically, - // so this allows for types not yet fully refined. - auto lhs_et = lhs_tt.getElementType(); - auto rhs_et = rhs_tt.getElementType(); - if (lhs_et == rhs_et) return true; - - // Remove ref types. - if (may_ignore_ref_type_lhs) { - if (auto ref_type = lhs_et.dyn_cast()) { - lhs_et = ref_type.RemoveRef(); - if (lhs_et == rhs_et) return true; - } - } - - if (lhs_et.getKind() != rhs_et.getKind()) return false; - - // If either is not type that contain subtypes then the element types don't - // match. - auto lhs_wst = lhs_et.dyn_cast(); - auto rhs_wst = rhs_et.dyn_cast(); - if (!lhs_wst || !rhs_wst) return false; - - // Consider the subtype recursively. - auto lhs_wst_st = lhs_wst.GetSubtypes(); - auto rhs_wst_st = rhs_wst.GetSubtypes(); - if (lhs_wst_st.empty() || rhs_wst_st.empty()) return true; - if (lhs_wst_st.size() != rhs_wst_st.size()) return false; - for (auto subtypes : llvm::zip(lhs_wst_st, rhs_wst_st)) { - if (!HasCompatibleElementTypes(std::get<0>(subtypes), - std::get<1>(subtypes))) - return false; +bool AreCastCompatible(ArrayRef types) { + Type common = types.front(); + for (auto type : types.drop_front()) { + Type refined_type = + GetCastCompatibleType(common, type, /*may_ignore_ref_type_a=*/false); + if (!refined_type) return false; + common = refined_type; } return true; } +ShapedType DropTypeSubTypes(ShapedType ty) { + Type element_ty = ty.getElementType(); + auto subtype_ty = element_ty.dyn_cast(); + if (!subtype_ty) return ty; + + Type default_ty = GetDefaultTypeOf(subtype_ty); + if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty); + + return UnrankedTensorType::get(default_ty); +} + +ShapedType DropRefType(ShapedType ty) { + Type element_ty = ty.getElementType(); + TF::TensorFlowRefType ref_ty = element_ty.dyn_cast(); + if (!ref_ty) return ty; + + Type default_ty = TF::GetDefaultTypeOf(ref_ty); + if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty); + + return UnrankedTensorType::get(default_ty); +} + } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h index d1e6a74a0c5..5f108e834a9 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h @@ -313,6 +313,22 @@ bool BroadcastCompatible(ArrayRef lhs, ArrayRef rhs); bool HasCompatibleElementTypes(Type lhs, Type rhs, bool may_ignore_ref_type_lhs = false); +// Returns true if all TensorFlow types can be cast to one +// another. In other words, a single run-time value is legal for both the types. +// For example, tensor<*xf32>, tensor and tensor<3xf32> are cast +// compatible. +bool AreCastCompatible(ArrayRef types); + +// If the given tensor has elements of type with subtypes, then returns a new +// type after dropping subtypes info. Otherwise, returns the original type as +// is. +ShapedType DropTypeSubTypes(ShapedType ty); + +// If the given tensor has elements of type ref, then returns a new type +// of the shape, but corresponding non-ref type as element type. Otherwise, +// returns the original type as is. 
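Stepping back to AreCastCompatible above: it generalizes the old pairwise check by folding the type refinement over the whole list, so that a single run-time value must be legal for every type in it. A standalone sketch of that fold (plain C++ with strings standing in for types; RefinePair stands in for GetCastCompatibleType and is an assumption of this sketch, not real API):

#include <functional>
#include <optional>
#include <string>
#include <vector>

// Returns true when every type in `types` can be cast to a single common
// refined type. `RefinePair` returns the refined type of two compatible
// types, or nullopt when they are not compatible.
bool CastCompatibleList(
    const std::vector<std::string>& types,
    const std::function<std::optional<std::string>(
        const std::string&, const std::string&)>& RefinePair) {
  if (types.empty()) return true;
  std::string common = types.front();
  for (size_t i = 1; i < types.size(); ++i) {
    std::optional<std::string> refined = RefinePair(common, types[i]);
    if (!refined) return false;  // no common run-time value satisfies both
    common = *refined;           // keep folding against the refined type
  }
  return true;
}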
+ShapedType DropRefType(ShapedType ty); + } // end namespace TF } // end namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index 18f8d5f4486..a77aa5b8346 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -258,6 +258,59 @@ func @testDoubleReciprocal(%arg0: tensor<8x16x32x64xi32>) -> tensor<8x16x32x64xi // CHECK: return %arg0 } +// CHECK-LABEL: testSelectScalarPred +func @testSelectScalarPred(%arg0: tensor, %arg1: tensor<4x2xf16>, %arg2: tensor<4x2xf16>) -> tensor<4x2xf16> { + // CHECK-NEXT: "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor, tensor<4x2xf16>, tensor<4x2xf16>) -> tensor<4x2xf16> + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor, tensor<4x2xf16>, tensor<4x2xf16>) -> tensor<4x2xf16> + return %0: tensor<4x2xf16> +} + +// CHECK-LABEL: testSelectVectorPred +func @testSelectVectorPred(%arg0: tensor<2xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> { + // CHECK-NEXT: %[[SHAPE:.*]] = "tf.Const" + // CHECK-NEXT: %[[PRED:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<2xi1>, tensor<2xi64>) -> tensor<2x1xi1> + // CHECK-NEXT: "tf.SelectV2"(%[[PRED]], %arg1, %arg2) : (tensor<2x1xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16> + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16> + return %0: tensor<2x3xf16> +} + +// CHECK-LABEL: testSelectAllSameShape +func @testSelectAllSameShape(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> { + // CHECK-NEXT: "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16> + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16> + return %0: tensor<2x3xf16> +} + +// If we don't have guarantees on input shapes, we can't support canonicalizing +// to SelectV2. Test these cases. 
+// CHECK-LABEL: testSelectInvalid +func @testSelectInvalid(%arg0: tensor, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> { + // CHECK-NEXT: tf.Select + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16> + return %0: tensor<2x3xf16> +} + +// CHECK-LABEL: testSelectInvalidUnranked +func @testSelectInvalidUnranked(%arg0: tensor<6x7xi1>, %arg1: tensor<*xf16>, %arg2: tensor<*xf16>) -> tensor<*xf16> { + // CHECK-NEXT: tf.Select + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<6x7xi1>, tensor<*xf16>, tensor<*xf16>) -> tensor<*xf16> + return %0: tensor<*xf16> +} + +// CHECK-LABEL: testSelectThenUnranked +func @testSelectThenUnranked(%arg0: tensor<3xi1>, %arg1: tensor<*xf16>, %arg2: tensor<3x2xf16>) -> tensor<*xf16> { + // CHECK-NEXT: tf.Select + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<*xf16>, tensor<3x2xf16>) -> tensor<*xf16> + return %0: tensor<*xf16> +} + +// CHECK-LABEL: testSelectElseUnranked +func @testSelectElseUnranked(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2: tensor<*xf16>) -> tensor<*xf16> { + // CHECK-NEXT: tf.Select + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<3x2xf16>, tensor<*xf16>) -> tensor<*xf16> + return %0: tensor<*xf16> +} + // CHECK-LABEL: testLogicalNotOfEqual func @testLogicalNotOfEqual(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xi1> { %0 = "tf.Equal"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1> @@ -471,3 +524,22 @@ func @testRankOfRankedTensor(%arg0 : tensor<4x3x2xf32>) -> tensor { // CHECK: return [[VAL0]] return %0 : tensor } + +// CHECK-LABEL: @foldFill +func @foldFill() -> (tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex>) { + %0 = "tf.Const"() {value = dense<[3, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32> + %1 = "tf.Const"() {value = dense<23.0> : tensor} : () -> tensor + // CHECK: "tf.Const"() {value = dense<2.300000e+01> : tensor<3x2x1xf32>} + %2 = "tf.Fill"(%0, %1) : (tensor<3xi32>, tensor) -> tensor<3x2x1xf32> + // CHECK: "tf.Const"() {value = dense<2.300000e+01> : tensor<3x2x1xf32>} + %3 = "tf.Fill"(%0, %1) : (tensor<3xi32>, tensor) -> tensor<*xf32> + + %complex_cst = "tf.Const"() {value = dense<(0.000000e+00,1.000000e+00)> : tensor>} : () -> tensor> + // Here, custom folder doesn't handle complex dtypes and it is folded through + // the constant folding hook. + // TODO(hinsu): Handle complex dtypes in the custom folder for FillOp. 
+ // CHECK: "tf.Const"() {value = dense<(0.000000e+00,1.000000e+00)> : tensor<3x2x1xcomplex>} : () -> tensor<*xcomplex> + %4 = "tf.Fill"(%0, %complex_cst) : (tensor<3xi32>, tensor>) -> tensor<*xcomplex> + + return %2, %3, %4 : tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir index bccb8923134..3ae6023400c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir @@ -302,15 +302,13 @@ func @testTensorListElementShape(%arg0: tensor>>) -> return %0: tensor<2xi32> } -func @RemoveTrivialAdd(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { +func @RemoveTrivialAdd(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { %cst = constant dense<0.0> : tensor<2x2xf32> - %0 = "tf.Add"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - %1 = "tf.Add"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - return %1 : tensor<2x2xf32> + %0 = "tf.Add"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> // CHECK-LABEL: RemoveTrivialAdd - // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32> + // CHECK-NEXT: return %arg0 : tensor<2x2xf32> } func @RemoveTrivialAddBf16RHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> { @@ -331,26 +329,22 @@ func @RemoveTrivialAddBf16LHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> { // CHECK-NEXT: return %arg0 : tensor<2x2xbf16> } -func @RemoveTrivialAddV2(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { +func @RemoveTrivialAddV2(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { %cst = constant dense<0.0> : tensor<2x2xf32> - %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - %1 = "tf.AddV2"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - return %1 : tensor<2x2xf32> + %0 = "tf.AddV2"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> // CHECK-LABEL: RemoveTrivialAddV2 - // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32> + // CHECK-NEXT: return %arg0 : tensor<2x2xf32> } -func @RemoveTrivialSub(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { +func @RemoveTrivialSub(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { %cst = constant dense<0.0> : tensor<2x2xf32> - %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - %1 = "tf.Sub"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - return %1 : tensor<2x2xf32> + %0 = "tf.Sub"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> // CHECK-LABEL: RemoveTrivialSub - // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32> + // CHECK-NEXT: return %arg0 : tensor<2x2xf32> } func @RemoveTrivialSubInt8(%arg0: tensor<2x2xi8>) -> tensor<2x2xi8> { @@ -362,26 +356,31 @@ func @RemoveTrivialSubInt8(%arg0: tensor<2x2xi8>) -> tensor<2x2xi8> { // CHECK-NEXT: return %arg0 : tensor<2x2xi8> } -func @RemoveTrivialMul(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { +func 
@RemoveTrivialMul(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { %cst = constant dense<1.0> : tensor<2x2xf32> - %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - %1 = "tf.Mul"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - return %1 : tensor<2x2xf32> + %0 = "tf.Mul"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> // CHECK-LABEL: RemoveTrivialMul - // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32> + // CHECK-NEXT: return %arg0 : tensor<2x2xf32> } -func @RemoveTrivialDiv(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { +func @RemoveTrivialDiv(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { %cst = constant dense<1.0> : tensor<2x2xf32> - %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - %1 = "tf.Div"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - return %1 : tensor<2x2xf32> + %0 = "tf.Div"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> // CHECK-LABEL: RemoveTrivialDiv - // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32> + // CHECK-NEXT: return %arg0 : tensor<2x2xf32> +} + +func @RemoveTrivialRealDiv(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { + %cst = constant dense<1.0> : tensor<2x2xf32> + %0 = "tf.RealDiv"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> + + // CHECK-LABEL: RemoveTrivialRealDiv + // CHECK-NEXT: return %arg0 : tensor<2x2xf32> } func @RemoveTrivialDivBf16RHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> { @@ -411,28 +410,35 @@ func @DivBf16LHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> { // CHECK: tf.Div } -func @DontRemoveTrivialAdd(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<2x2xf32> { +func @DontRemoveTrivialAdd(%arg0: tensor<1x2xf32>) -> tensor<2x2xf32> { %cst = constant dense<0.0> : tensor<2x2xf32> - %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32> - %1 = "tf.AddV2"(%0, %cst) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - return %1 : tensor<2x2xf32> + %0 = "tf.AddV2"(%arg0, %cst) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> // CHECK-LABEL: DontRemoveTrivialAdd // CHECK: %[[CONST:.*]] = constant dense<0.000000e+00> : tensor<2x2xf32> - // CHECK: %[[add:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32> - // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%[[add]], %[[CONST]]) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %[[CONST]]) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // CHECK: return %[[RESULT]] : tensor<2x2xf32> } -func @DontRemoveTrivialAdd2(%arg0: tensor, %arg1: tensor<2x2xf32>) -> tensor { +func @DontRemoveTrivialAdd2(%arg0: tensor) -> tensor { %cst = constant dense<0.0> : tensor<2x2xf32> - %0 = "tf.AddV2"(%arg0, %arg1) : (tensor, tensor<2x2xf32>) -> tensor - %1 = "tf.AddV2"(%0, %cst) : (tensor , tensor<2x2xf32>) -> tensor - return %1 :tensor + %0 = "tf.AddV2"(%arg0, %cst) : (tensor , tensor<2x2xf32>) -> tensor + return %0 :tensor // CHECK-LABEL: DontRemoveTrivialAdd2 // CHECK: %[[CONST:.*]] = constant dense<0.000000e+00> : tensor<2x2xf32> - // 
CHECK: %[[add:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor, tensor<2x2xf32>) -> tensor - // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%[[add]], %[[CONST]]) : (tensor, tensor<2x2xf32>) -> tensor + // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %[[CONST]]) : (tensor, tensor<2x2xf32>) -> tensor // CHECK: return %[[RESULT]] : tensor } + +// Test no fold because of the broadcast. +func @DontRemoveTrivialMul(%arg0: tensor<1x6x8x1xf32>) -> tensor<1x6x8x1xf32> { + %0 = "tf.Const"() {value = dense<2.000000e+00> : tensor} : () -> tensor + %1 = "tf.Mul"(%arg0, %0) : (tensor<1x6x8x1xf32>, tensor) -> tensor<1x6x8x1xf32> + return %1 : tensor<1x6x8x1xf32> + // CHECK-LABEL: DontRemoveTrivialMul + // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor} : () -> tensor + // CHECK: %[[RESULT:.*]] = "tf.Mul"(%arg0, %[[CONST]]) : (tensor<1x6x8x1xf32>, tensor) -> tensor<1x6x8x1xf32> + // CHECK: return %[[RESULT]] : tensor<1x6x8x1xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/case_op.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/case_op.mlir new file mode 100644 index 00000000000..de6f9b42ba4 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/case_op.mlir @@ -0,0 +1,47 @@ +// RUN: tf-opt %s -tf-executor-tpu-v1-island-outlining | FileCheck %s --dump-input=fail + +// CHECK: func @control_input +// CHECK-NOT: func @ +// CHECK-LABEL: module @_tpu_v1_compat_outlined +// CHECK: @_tpu_v1_compat_outlined_func0 +// CHECK: func @branch_0 +// CHECK: func @branch_1 +// CHECK: func @branch_2 +// CHECK: func @branch_3 +// CHECK: func @branch_4 +module { + func @control_input(%arg0: tensor) -> tensor { + %0 = tf_executor.graph { + %output, %control = tf_executor.island { + "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "device", num_replicas = 1, topology = "topology"} : () -> () + %index = "tf.opA"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> tensor + %input = "tf.opB"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> tensor + %result = "tf.Case"(%index, %input) {branches = [@branch_0, @branch_1, @branch_2, @branch_3, @branch_4]} : (tensor, tensor) -> tensor + tf_executor.yield %result : tensor + } + tf_executor.fetch %output : tensor + + } + return %0 : tensor + } + func @branch_0(%arg0: tensor) -> tensor { + %0 = "tf.some_op"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> tensor + return %0 : tensor + } + func @branch_1(%arg0: tensor) -> tensor { + %0 = "tf.some_op"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> tensor + return %0 : tensor + } + func @branch_2(%arg0: tensor) -> tensor { + %0 = "tf.some_op"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> tensor + return %0 : tensor + } + func @branch_3(%arg0: tensor) -> tensor { + %0 = "tf.some_op"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> tensor + return %0 : tensor + } + func @branch_4(%arg0: tensor) -> tensor { + %0 = "tf.some_op"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> tensor + return %0 : tensor + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/func-attr-invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/func-attr-invalid.mlir new file mode 100644 index 00000000000..cd3b8b55032 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/func-attr-invalid.mlir @@ -0,0 +1,50 @@ +// RUN: tf-opt %s -split-input-file -verify-diagnostics + +// Tests invalid #tf.func attributes. 
+ +// expected-error@+1 {{invalid TensorFlow func attribute: func}} +func @main() attributes {tf._implements = #tf.func} { + return +} + +// ----- + +// expected-error@+1 {{invalid TensorFlow func attribute: func<>}} +func @main() attributes {tf._implements = #tf.func<>} { + return +} + +// ----- + +// expected-error@+1 {{invalid TensorFlow func attribute: func<@symbol>}} +func @main() attributes {tf._implements = #tf.func<@symbol>} { + return +} + +// ----- + +// expected-error@+1 {{invalid TensorFlow func attribute: func<{}>}} +func @main() attributes {tf._implements = #tf.func<{}>} { + return +} + +// ----- + +// expected-error@+1 {{invalid TensorFlow func attribute: func<"test", {}>}} +func @main() attributes {tf._implements = #tf.func<"test", {}>} { + return +} + +// ----- + +// expected-error@+1 {{invalid TensorFlow func attribute: func<@symbol, "">}} +func @main() attributes {tf._implements = #tf.func<@symbol, "">} { + return +} + +// ----- + +// expected-error@+1 {{invalid TensorFlow func attribute: func<@symbol, {}, "">}} +func @main() attributes {tf._implements = #tf.func<@symbol, {}, "">} { + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/func-attr.mlir b/tensorflow/compiler/mlir/tensorflow/tests/func-attr.mlir new file mode 100644 index 00000000000..de17778c105 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/func-attr.mlir @@ -0,0 +1,13 @@ +// RUN: tf-opt %s | tf-opt | FileCheck %s --dump-input=fail + +// CHECK-LABEL: func @func_attr +// CHECK-SAME: tf._implements = #tf.func<@symbol_a, {attr0 = 1 : i32, attr1 = "random"}> +func @func_attr() attributes {tf._implements = #tf.func<@symbol_a, {attr0 = 1 : i32, attr1 = "random"}>} { + return +} + +// CHECK-LABEL: func @nested_func_attr +// CHECK-SAME: tf._implements = #tf.func<@symbol_a, {attr0 = 1 : i32, attr1 = "random", nested = #tf.func<@symbol_b, {attr2 = true, attr3 = 8.000000e+00 : f32}>}> +func @nested_func_attr() attributes {tf._implements = #tf.func<@symbol_a, {attr0 = 1 : i32, attr1 = "random", nested = #tf.func<@symbol_b, {attr2 = true, attr3 = 8.0 : f32}>}>} { + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/function-func-attr.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/function-func-attr.pbtxt new file mode 100644 index 00000000000..9f044c62736 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/function-func-attr.pbtxt @@ -0,0 +1,53 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -o - | FileCheck %s --dump-input-on-failure + +node { + name: "custom_relu_func_call" + op: "custom_relu" +} +node { + name: "custom_embedding_matmul_func_call" + op: "custom_embedding_matmul" +} +library { + function { + signature { + name: "custom_relu" + } + attr { + key: "_implements" + value { + func { + name: "tensorflow.relu" + } + } + } + } + function { + signature { + name: "custom_embedding_matmul" + } + attr { + key: "_implements" + value { + func { + name: "tensorflow.embedding_matmul" + attr { + key: "key1" + value { + i: 2 + } + } + attr { + key: "key2" + value { + b: false + } + } + } + } + } + } +} + +# CHECK-DAG: func @custom_relu{{[0-9]*}}() attributes {tf._implements = #tf.func<@tensorflow.relu, {}>} +# CHECK-DAG: func @custom_embedding_matmul{{[0-9]*}}() attributes {tf._implements = #tf.func<@tensorflow.embedding_matmul, {key1 = 2 : i64, key2 = false}>} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-as-function.pbtxt 
b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-as-function.pbtxt index d26585edb03..03640e24aac 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-as-function.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-as-function.pbtxt @@ -13,7 +13,7 @@ # CHECK: %[[ISLAND_2:.*]], %[[ISLAND_2_control:.*]] = tf_executor.island wraps "tf.StatefulPartitionedCall" # CHECK-SAME: f = @[[FUNC:[a-z0-9]*]] # CHECK: tf_executor.fetch %[[ISLAND_1]], %[[ISLAND_2]] : tensor<*xf32>, tensor<*xf32> -# CHECK: func @[[FUNC]](%arg0: tensor<*xf32>, %arg1: tensor<*x!tf.resource>) -> tensor<*xf32> +# CHECK: func @[[FUNC]](%arg0: tensor<*xf32> {tf._user_specified_name = "inputs"}, %arg1: tensor<*x!tf.resource>) -> tensor<*xf32> node { name: "args_0" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-resource-args.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-resource-args.pbtxt index 0e6e561225d..eb358d52b26 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-resource-args.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-resource-args.pbtxt @@ -81,7 +81,7 @@ library { } # Check that the `resource_arg_unique_id` for each argument is propagated to the -# `tf.resource_arg_unique_id` argument attribute of the function +# `tf._resource_arg_unique_id` argument attribute of the function # @test_func_name0. # CHECK: func @main @@ -92,8 +92,8 @@ library { # CHECK: tf_executor.fetch # CHECK: return # CHECK: func @test_func_name0 -# CHECK-SAME: tf.resource_arg_unique_id = 0 -# CHECK-SAME: tf.resource_arg_unique_id = 0 +# CHECK-SAME: tf._resource_arg_unique_id = 0 +# CHECK-SAME: tf._resource_arg_unique_id = 0 # CHECK: tf_executor.graph # CHECK: tf_executor.fetch # CHECK: return diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir index 10cb4f8019d..198227bf5dd 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir @@ -2,17 +2,17 @@ func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> return %0 : tensor<1x32x10x32xi32> } func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> return %0 : tensor<1x32x10x32xi32> } func @biasAdd_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor, tensor) -> tensor + %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor, tensor) -> tensor return %0 : tensor } @@ -23,12 +23,12 @@ func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> { } func @broadcast_add(%arg0: tensor<1xi32>, %arg1: 
tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> + %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> return %0 : tensor<4x4x4x4xi32> } @@ -38,7 +38,7 @@ func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { } func @broadcast_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } @@ -48,7 +48,7 @@ func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { } func @div_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor) -> tensor + %0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor) -> tensor return %0 : tensor } @@ -68,7 +68,7 @@ func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { } func @broadcast_mul(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "xla_hlo.multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "xla_chlo.broadcast_multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } @@ -78,7 +78,7 @@ func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> { } func @broadcast_real_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } @@ -88,7 +88,7 @@ func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> { } func @broadcast_sub(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "xla_hlo.subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "xla_chlo.broadcast_subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } @@ -98,7 +98,7 @@ func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { } func @broadcast_shift_right(%arg0: tensor<4xi32>, %arg1: tensor<2x4xi32>) -> tensor<2x4xi32> { - %0 = "xla_hlo.shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<2x4xi32>) -> 
tensor<2x4xi32> + %0 = "xla_chlo.broadcast_shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> return %0 : tensor<2x4xi32> } @@ -108,12 +108,12 @@ func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> { } func @and_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> { - %0 = "xla_hlo.and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @and_dynamic(%arg0: tensor, %arg1: tensor<1xi1>) -> tensor { - %0 = "xla_hlo.and"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor + %0 = "xla_chlo.broadcast_and"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor return %0 : tensor } @@ -123,12 +123,12 @@ func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> { } func @or_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> { - %0 = "xla_hlo.or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @or_dynamic(%arg0: tensor, %arg1: tensor<1xi1>) -> tensor { - %0 = "xla_hlo.or"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor + %0 = "xla_chlo.broadcast_or"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor return %0 : tensor } @@ -138,12 +138,12 @@ func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { } func @bitwise_or_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { - %0 = "xla_hlo.or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> + %0 = "xla_chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> return %0 : tensor<1x4xi8> } func @bitwise_or_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - %0 = "xla_hlo.or"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor + %0 = "xla_chlo.broadcast_or"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } @@ -153,12 +153,12 @@ func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { } func @bitwise_and_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { - %0 = "xla_hlo.and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> + %0 = "xla_chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> return %0 : tensor<1x4xi8> } func @bitwise_and_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - %0 = "xla_hlo.and"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor + %0 = "xla_chlo.broadcast_and"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } @@ -174,19 +174,19 @@ func @pow_dynamic(%arg0: tensor) -> tensor { func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { %0 = xla_hlo.constant dense<0> : tensor<2x3xi32> - %1 = "xla_hlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> + %1 = "xla_chlo.broadcast_compare"(%arg0, %0) {comparison_direction = "LT"} : 
(tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> %2 = xla_hlo.constant dense<0> : tensor<3xi32> - %3 = "xla_hlo.compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> - %4 = "xla_hlo.compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1> - %5 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> + %3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> + %4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1> + %5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> %6 = "xla_hlo.abs"(%arg0) : (tensor<2x3xi32>) -> tensor<2x3xi32> %7 = "xla_hlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32> %8 = xla_hlo.constant dense<1> : tensor<3xi32> %9 = xla_hlo.subtract %7, %8 : tensor<3xi32> - %10 = "xla_hlo.add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> + %10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> %11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32> %12 = "xla_hlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32> - %13 = "xla_hlo.divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> + %13 = "xla_chlo.broadcast_divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> %14 = "xla_hlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %14 : tensor<2x3xi32> } @@ -195,14 +195,14 @@ func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32 %0 = xla_hlo.constant dense<0> : tensor<3xi32> %1 = "xla_hlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> %2 = xla_hlo.constant dense<0> : tensor<2x3xi32> - %3 = "xla_hlo.compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> - %4 = "xla_hlo.compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1> - %5 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> + %4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1> + %5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> %6 = "xla_hlo.abs"(%arg0) : (tensor<3xi32>) -> tensor<3xi32> %7 = "xla_hlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32> %8 = xla_hlo.constant dense<1> : tensor<2x3xi32> %9 = xla_hlo.subtract %7, %8 : tensor<2x3xi32> - %10 = "xla_hlo.add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, 
tensor<2x3xi32>) -> tensor<2x3xi32> + %10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> %11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32> %12 = "xla_hlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32> %13 = xla_hlo.divide %11, %12 : tensor<2x3xi32> @@ -218,8 +218,8 @@ func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { } func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> { - %0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> - %1 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> + %0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> + %1 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> %2 = "xla_hlo.floor"(%1) : (tensor<2x3xf16>) -> tensor<2x3xf16> return %2 : tensor<2x3xf16> } @@ -230,22 +230,22 @@ func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { } func @equal_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor, tensor<1xi32>) -> tensor + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @equal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @equal_incompatible_shape_broadcastable(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor, tensor<1xi32>) -> tensor + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } @@ -255,17 +255,17 @@ func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> { } func @notequal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @notequal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: 
tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @notequal_incompatible_shape_broadcastable(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor, tensor<1xi32>) -> tensor + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } @@ -275,7 +275,7 @@ func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> { } func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } @@ -285,7 +285,7 @@ func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { } func @broadcast_greater_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } @@ -295,7 +295,7 @@ func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> { } func @broadcast_less(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } @@ -305,7 +305,7 @@ func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { } func @broadcast_less_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } @@ -326,35 +326,35 @@ func @const() -> tensor<2xi32> { func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { %0 = xla_hlo.constant dense<0> : tensor - %1 = "xla_hlo.maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> tensor<1xi32> + %1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> tensor<1xi32> return %1 : tensor<1xi32> } func @relu_unranked(%arg0: tensor) -> tensor { %0 = xla_hlo.constant dense<0> : tensor - %1 = "xla_hlo.maximum"(%0, %arg0) {broadcast_dimensions = 
dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor + %1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor return %1 : tensor } func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> { %0 = xla_hlo.constant dense<0> : tensor %1 = xla_hlo.constant dense<6> : tensor - %2 = "xla_hlo.minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor) -> tensor<1xi32> - %3 = "xla_hlo.maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor) -> tensor<1xi32> + %2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor) -> tensor<1xi32> + %3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor) -> tensor<1xi32> return %3 : tensor<1xi32> } func @relu6_unranked(%arg0: tensor) -> tensor { %0 = xla_hlo.constant dense<0> : tensor %1 = xla_hlo.constant dense<6> : tensor - %2 = "xla_hlo.minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor - %3 = "xla_hlo.maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor + %2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor + %3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor return %3 : tensor } func @relu_grad(%arg0: tensor<4x8xf32>, %arg1: tensor) -> tensor<4x8xf32> { %0 = xla_hlo.constant dense<0.000000e+00> : tensor - %1 = "xla_hlo.compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor, tensor) -> tensor + %1 = "xla_chlo.broadcast_compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor, tensor) -> tensor %2 = xla_hlo.constant dense<0.000000e+00> : tensor<4x8xf32> %3 = "xla_hlo.select"(%1, %arg0, %2) : (tensor, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> return %3 : tensor<4x8xf32> @@ -682,6 +682,37 @@ func @convert_i32_f32(%arg0: tensor<2xi32>) -> tensor<2xf32> { return %0 : tensor<2xf32> } +func @convert_slice(%arg0: tensor<1x4672xf32>) -> tensor<1x519xf32> { + %0 = "xla_hlo.slice"(%arg0) {limit_indices = dense<[1, 4672]> : tensor<2xi64>, start_indices = dense<[0, 4153]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x4672xf32>) -> tensor<1x519xf32> + return %0 : tensor<1x519xf32> +} + +func @reshape(%arg0: tensor<4x6xf32>) -> tensor<2x2x6xf32> { + %0 = "xla_hlo.reshape"(%arg0) : (tensor<4x6xf32>) -> tensor<2x2x6xf32> + return %0 : tensor<2x2x6xf32> + +} + +func @convert_dot_1d_2d(%arg0: tensor<256xf32>, %arg1: tensor<256x1xf32>) -> tensor<1xf32> { + %0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<256xf32>, tensor<256x1xf32>) -> tensor<1xf32> + return %0 : tensor<1xf32> +} + +func @convert_dot_2d_1d(%arg0: tensor<1x256xf32>, %arg1: tensor<256xf32>) -> tensor<1xf32> { + %0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x256xf32>, tensor<256xf32>) -> tensor<1xf32> + return %0 : tensor<1xf32> +} + +func @convert_dot_1d_1d(%arg0: tensor<256xf32>, %arg1: tensor<256xf32>) -> tensor { + %0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<256xf32>, tensor<256xf32>) -> tensor + return %0 : tensor +} + +func @convert_dot_2d_2d(%arg0: 
tensor<1x256xf32>, %arg1: tensor<256x1xf32>) -> tensor<1x1xf32> { + %0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32> + return %0 : tensor<1x1xf32> +} + // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py // CHECK-LABEL: func @biasAdd_NHWC( @@ -1493,3 +1524,49 @@ func @convert_i32_f32(%arg0: tensor<2xi32>) -> tensor<2xf32> { // CHECK: [[VAL_371:%.*]] = "tf.Cast"([[VAL_370]]) {Truncate = false} : (tensor<2xi32>) -> tensor<2xf32> // CHECK: return [[VAL_371]] : tensor<2xf32> // CHECK: } + +// CHECK-LABEL: func @convert_slice( +// CHECK-SAME: [[VAL_372:%.*]]: tensor<1x4672xf32>) -> tensor<1x519xf32> { +// CHECK: [[VAL_373:%.*]] = "tf.Const"() {value = dense<[0, 4153]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: [[VAL_374:%.*]] = "tf.Const"() {value = dense<[1, 519]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: [[VAL_375:%.*]] = "tf.Slice"([[VAL_372]], [[VAL_373]], [[VAL_374]]) : (tensor<1x4672xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x519xf32> +// CHECK: return [[VAL_375]] : tensor<1x519xf32> +// CHECK: } + +// CHECK-LABEL: func @reshape( +// CHECK-SAME: [[VAL_372:%.*]]: tensor<4x6xf32>) -> tensor<2x2x6xf32> { +// CHECK: [[VAL_373:%.*]] = constant dense<[2, 2, 6]> : tensor<3xi64> +// CHECK: [[VAL_374:%.*]] = "tf.Reshape"([[VAL_372]], [[VAL_373]]) : (tensor<4x6xf32>, tensor<3xi64>) -> tensor<2x2x6xf32> +// CHECK: return [[VAL_374]] : tensor<2x2x6xf32> +// CHECK: } + +// CHECK-LABEL: func @convert_dot_1d_2d( +// CHECK-SAME: [[VAL_376:%.*]]: tensor<256xf32>, [[VAL_377:%.*]]: tensor<256x1xf32>) -> tensor<1xf32> { +// CHECK: [[VAL_378:%.*]] = "tf.Reshape"([[VAL_376]], {{.*}}) : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> +// CHECK: [[VAL_379:%.*]] = "tf.MatMul"([[VAL_378]], [[VAL_377]]) {transpose_a = false, transpose_b = false} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32> +// CHECK: [[VAL_380:%.*]] = "tf.Reshape"([[VAL_379]], {{.*}}) : (tensor<1x1xf32>, tensor<1xi64>) -> tensor<1xf32> +// CHECK: return [[VAL_380]] : tensor<1xf32> +// CHECK: } + +// CHECK-LABEL: func @convert_dot_2d_1d( +// CHECK-SAME: [[VAL_381:%.*]]: tensor<1x256xf32>, [[VAL_382:%.*]]: tensor<256xf32>) -> tensor<1xf32> { +// CHECK: [[VAL_383:%.*]] = "tf.Reshape"([[VAL_382]], {{.*}}) : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> +// CHECK: [[VAL_384:%.*]] = "tf.MatMul"([[VAL_381]], [[VAL_383]]) {transpose_a = false, transpose_b = true} : (tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x1xf32> +// CHECK: [[VAL_385:%.*]] = "tf.Reshape"([[VAL_384]], {{.*}}) : (tensor<1x1xf32>, tensor<1xi64>) -> tensor<1xf32> +// CHECK: return [[VAL_385]] : tensor<1xf32> +// CHECK: } + +// CHECK-LABEL: func @convert_dot_1d_1d( +// CHECK-SAME: [[VAL_386:%.*]]: tensor<256xf32>, [[VAL_387:%.*]]: tensor<256xf32>) -> tensor { +// CHECK-DAG: [[VAL_388:%.*]] = "tf.Reshape"([[VAL_386]], {{.*}}) : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> +// CHECK-DAG: [[VAL_389:%.*]] = "tf.Reshape"([[VAL_387]], {{.*}}) : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> +// CHECK: [[VAL_390:%.*]] = "tf.MatMul"([[VAL_388]], [[VAL_389]]) {transpose_a = false, transpose_b = true} : (tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x1xf32> +// CHECK: [[VAL_391:%.*]] = "tf.Reshape"([[VAL_390]], {{.*}}) : (tensor<1x1xf32>, tensor<0xi64>) -> tensor +// CHECK: return [[VAL_391]] : tensor +// CHECK: } + +// CHECK-LABEL: func @convert_dot_2d_2d( +// CHECK-SAME: [[VAL_392:%.*]]: tensor<1x256xf32>, 
[[VAL_393:%.*]]: tensor<256x1xf32>) -> tensor<1x1xf32> { +// CHECK: [[VAL_394:%.*]] = "tf.MatMul"([[VAL_392]], [[VAL_393]]) {transpose_a = false, transpose_b = false} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32> +// CHECK: return [[VAL_394]] : tensor<1x1xf32> +// CHECK: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/function-resource-args.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/function-resource-args.mlir index 680e26f5cbb..44824ea1424 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/function-resource-args.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/function-resource-args.mlir @@ -8,14 +8,14 @@ func @main() -> tensor<*x!tf.resource> attributes {tf.entry_function = {inputs = } return %0 : tensor<*x!tf.resource> } -func @test_func_name0(%arg0: tensor<*x!tf.resource> {tf.resource_arg_unique_id = 0 : i64}, %arg1: tensor<*x!tf.resource> {tf.resource_arg_unique_id = 0 : i64}) -> tensor<*x!tf.resource> attributes {tf._disable_call_shape_inference = true} { +func @test_func_name0(%arg0: tensor<*x!tf.resource> {tf._resource_arg_unique_id = 0 : i64}, %arg1: tensor<*x!tf.resource> {tf._resource_arg_unique_id = 0 : i64}) -> tensor<*x!tf.resource> attributes {tf._disable_call_shape_inference = true} { %0 = tf_executor.graph { tf_executor.fetch %arg0 : tensor<*x!tf.resource> } return %0 : tensor<*x!tf.resource> } -// Check that the `tf.resource_arg_unique_id` argument attributes of +// Check that the `tf._resource_arg_unique_id` argument attributes of // test_func_name0 are propagated to the function's arg_attr and // resource_arg_unique_id. diff --git a/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir b/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir index e7f4873594b..59c93a66d12 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir @@ -1,11 +1,11 @@ // RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-promote-resources-to-args | FileCheck %s -dump-input-on-failure // One resource, one read. The initial value of the resource is read. -// CHECK-LABEL: func @main(%arg0: tensor {tf.resource_name = "x"}) -> tensor<2xf32> -func @main() -> tensor<2xf32> { +// CHECK-LABEL: func @main(%arg0: tensor, %arg1: tensor {tf.resource_name = "x"}) -> tensor<2xf32> +func @main(%arg0: tensor) -> tensor<2xf32> { // CHECK-NOT: "tf.VarHandleOp" // CHECK-NOT: "tf.ReadVariableOp" - // CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%arg0, %[[CONST:[0-9]*]]) + // CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%arg1, %[[CONST:[0-9]*]]) // CHECK: %[[PACK:[0-9]*]] = "tf.Pack"(%[[CONST]], %[[ADD]]) // CHECK: return %[[PACK]] %0 = "tf.Const"() {value = dense<4.200000e+01> : tensor} : () -> tensor @@ -19,8 +19,8 @@ func @main() -> tensor<2xf32> { // ----- // One resource, one write. The initial value of the resource is not read. -// CHECK-LABEL: func @main() -> (tensor {tf.resource_name = "x"}) -func @main() { +// CHECK-LABEL: func @main(%arg0: tensor) -> (tensor {tf.resource_name = "x"}) +func @main(%arg0: tensor) { // CHECK-NOT: "tf.VarHandleOp" // CHECK-NOT: "tf.AssignVariableOp" // CHECK: return %[[CONST]] @@ -33,12 +33,12 @@ func @main() { // ----- // One resource, two reads using different resource handles. 
-// CHECK-LABEL: func @main(%arg0: tensor {tf.resource_name = "x"}) -> tensor<2xf32> -func @main() -> tensor<2xf32> { +// CHECK-LABEL: func @main(%arg0: tensor, %arg1: tensor {tf.resource_name = "x"}) -> tensor<2xf32> +func @main(%arg0: tensor) -> tensor<2xf32> { // CHECK-NOT: "tf.VarHandleOp" // CHECK-NOT: "tf.ReadVariableOp" - // CHECK: %[[ADD1:[0-9]*]] = "tf.AddV2"(%arg0, %[[CONST:[0-9]*]]) - // CHECK: %[[ADD2:[0-9]*]] = "tf.AddV2"(%[[ADD1]], %arg0) + // CHECK: %[[ADD1:[0-9]*]] = "tf.AddV2"(%arg1, %[[CONST:[0-9]*]]) + // CHECK: %[[ADD2:[0-9]*]] = "tf.AddV2"(%[[ADD1]], %arg1) // CHECK: %[[PACK:[0-9]*]] = "tf.Pack"(%[[CONST]], %[[ADD2]]) // CHECK: return %[[PACK]] @@ -56,12 +56,12 @@ func @main() -> tensor<2xf32> { // ----- // Two resources, two reads using different resources. -// CHECK-LABEL: func @main(%arg0: tensor {tf.resource_name = "x"}, %arg1: tensor {tf.resource_name = "y"}) -> tensor<2xf32> -func @main() -> tensor<2xf32> { +// CHECK-LABEL: func @main(%arg0: tensor, %arg1: tensor {tf.resource_name = "x"}, %arg2: tensor {tf.resource_name = "y"}) -> tensor<2xf32> +func @main(%arg0: tensor) -> tensor<2xf32> { // CHECK-NOT: "tf.VarHandleOp" // CHECK-NOT: "tf.ReadVariableOp" - // CHECK: %[[ADD1:[0-9]*]] = "tf.AddV2"(%arg0, %[[CONST:[0-9]*]]) - // CHECK: %[[ADD2:[0-9]*]] = "tf.AddV2"(%[[ADD1]], %arg1) + // CHECK: %[[ADD1:[0-9]*]] = "tf.AddV2"(%arg1, %[[CONST:[0-9]*]]) + // CHECK: %[[ADD2:[0-9]*]] = "tf.AddV2"(%[[ADD1]], %arg2) // CHECK: %[[PACK:[0-9]*]] = "tf.Pack"(%[[CONST]], %[[ADD2]]) // CHECK: return %[[PACK]] @@ -79,12 +79,12 @@ func @main() -> tensor<2xf32> { // ----- // One resource with read and write. The initial value of the resource is read. -// CHECK-LABEL: func @main(%arg0: tensor {tf.aliasing_output = 1 : i64, tf.resource_name = "x"}) -> (tensor<2xf32>, tensor) -func @main() -> tensor<2xf32> { +// CHECK-LABEL: func @main(%arg0: tensor, %arg1: tensor {tf.aliasing_output = 1 : i64, tf.resource_name = "x"}) -> (tensor<2xf32>, tensor) +func @main(%arg0: tensor) -> tensor<2xf32> { // CHECK-NOT: "tf.AssignVariableOp" - // CHECK: %[[ADD1:[0-9]*]] = "tf.AddV2"(%arg0, %{{[0-9]*}}) + // CHECK: %[[ADD1:[0-9]*]] = "tf.AddV2"(%arg1, %{{[0-9]*}}) // CHECK: %[[ADD2:[0-9]*]] = "tf.AddV2"(%[[ADD1]], %[[ADD1]]) - // CHECK: %[[PACK:[0-9]*]] = "tf.Pack"(%arg0, %[[ADD2]]) + // CHECK: %[[PACK:[0-9]*]] = "tf.Pack"(%arg1, %[[ADD2]]) // CHECK: return %[[PACK]], %[[ADD1]] %0 = "tf.Const"() {value = dense<4.200000e+01> : tensor} : () -> tensor @@ -102,8 +102,8 @@ func @main() -> tensor<2xf32> { // ----- // One resource with read and write. The initial value of the resource is not read. 
-// CHECK-LABEL: func @main() -> (tensor<2xf32>, tensor {tf.resource_name = "x"}) -func @main() -> tensor<2xf32> { +// CHECK-LABEL: func @main(%arg0: tensor) -> (tensor<2xf32>, tensor {tf.resource_name = "x"}) +func @main(%arg0: tensor) -> tensor<2xf32> { // CHECK-NOT: "tf.AssignVariableOp" // CHECK: %[[CONST:[a-z0-9]+]] = "tf.Const"() {value = dense<4.200000e+01> : tensor} // CHECK: %[[ADD1:[0-9]*]] = "tf.AddV2"(%[[CONST]], %[[CONST]]) @@ -138,15 +138,15 @@ func @cond_true(%arg0: tensor>>, %arg1: tensor) -> return %2 : tensor } -// CHECK-LABEL: func @main(%arg0: tensor {tf.resource_name = "x"}) -> tensor<2xf32> -func @main() -> tensor<2xf32> attributes {tf.entry_function = {inputs = "", outputs = "result"}} { +// CHECK-LABEL: func @main(%arg0: tensor, %arg1: tensor {tf.resource_name = "x"}) -> tensor<2xf32> +func @main(%arg0: tensor) -> tensor<2xf32> attributes {tf.entry_function = {inputs = "", outputs = "result"}} { %0 = "tf.Const"() {value = dense<1.050000e+03> : tensor} : () -> tensor %1 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> %2 = "tf.ReadVariableOp"(%1) : (tensor>>) -> tensor %3 = "tf.Less"(%2, %0) : (tensor, tensor) -> tensor %4 = "tf.If"(%3, %1, %2) {Tcond = i1, Tin = ["tfdtype$DT_RESOURCE", "tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], - else_branch = @cond_false, is_stateless = false, output_shapes = [#tf.shape<>], - then_branch = @cond_true} : (tensor, tensor>>, tensor) -> tensor + else_branch = @cond_false, is_stateless = false,then_branch = @cond_true} : + (tensor, tensor>>, tensor) -> tensor %5 = "tf.Identity"(%4) : (tensor) -> tensor %6 = "tf.Pack"(%2, %5) {N = 2 : i64, T = f32, axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xf32> return %6 : tensor<2xf32> @@ -157,10 +157,11 @@ func @main() -> tensor<2xf32> attributes {tf.entry_function = {inputs = "", outp // Tests resource passed in as an argument is not modified and not returned. // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG_0:[a-z0-9]+]]: tensor -func @main(%arg0: tensor>>) { - %0 = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor - // CHECK-NEXT: "tf.AddV2"(%[[ARG_0]], %[[ARG_0]]) +// CHECK-SAME: %arg0: tensor +// CHECK-SAME: %[[ARG_1:[a-z0-9]+]]: tensor +func @main(%arg0: tensor, %arg1: tensor>>) { + %0 = "tf.ReadVariableOp"(%arg1) : (tensor>>) -> tensor + // CHECK-NEXT: "tf.AddV2"(%[[ARG_1]], %[[ARG_1]]) %1 = "tf.AddV2"(%0, %0) : (tensor, tensor) -> tensor // CHECK-NEXT: return return @@ -171,9 +172,10 @@ func @main(%arg0: tensor>>) { // Tests resource passed in as an argument is modified but not returned. // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG_0:[a-z0-9]+]]: tensor {tf.aliasing_output = 0 : i64} +// CHECK-SAME: %{{[a-z0-9]+}}: tensor {tf.aliasing_output = 0 : i64} +// CHECK-SAME: %arg1: tensor // CHECK-SAME: -> tensor -func @main(%arg0: tensor>>) { +func @main(%arg0: tensor>>, %arg1: tensor) { // CHECK-NEXT: %[[CONST:[a-z0-9]+]] = "tf.Const" %0 = "tf.Const"() {value = dense<4.200000e+01> : tensor} : () -> tensor "tf.AssignVariableOp"(%arg0, %0) : (tensor>>, tensor) -> () @@ -186,9 +188,10 @@ func @main(%arg0: tensor>>) { // Tests last resource assign is returned as a result. 
// CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG_0:[a-z0-9]+]]: tensor {tf.aliasing_output = 0 : i64} +// CHECK-SAME: %{{[a-z0-9]+}}: tensor {tf.aliasing_output = 0 : i64} +// CHECK-SAME: %arg1: tensor // CHECK-SAME: -> tensor -func @main(%arg0: tensor>>) { +func @main(%arg0: tensor>>, %arg1: tensor) { %0 = "tf.Const"() {value = dense<4.200000e+01> : tensor} : () -> tensor "tf.AssignVariableOp"(%arg0, %0) : (tensor>>, tensor) -> () // CHECK: %[[CONST:[a-z0-9]+]] = "tf.Const"() {value = dense<1.050000e+03> : tensor} @@ -204,9 +207,10 @@ func @main(%arg0: tensor>>) { // returns the same value prior. // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG_0:[a-z0-9]+]]: tensor {tf.aliasing_output = 1 : i64} +// CHECK-SAME: %{{[a-z0-9]+}}: tensor {tf.aliasing_output = 1 : i64} +// CHECK-SAME: %arg1: tensor // CHECK-SAME: -> (tensor, tensor) -func @main(%arg0: tensor>>) -> tensor { +func @main(%arg0: tensor>>, %arg1: tensor) -> tensor { %0 = "tf.Const"() {value = dense<4.200000e+01> : tensor} : () -> tensor "tf.AssignVariableOp"(%arg0, %0) : (tensor>>, tensor) -> () // CHECK: %[[CONST:[a-z0-9]+]] = "tf.Const"() {value = dense<1.050000e+03> : tensor} @@ -221,9 +225,10 @@ func @main(%arg0: tensor>>) -> tensor { // Tests read interleaved between writes. // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG_0:[a-z0-9]+]]: tensor {tf.aliasing_output = 1 : i64} +// CHECK-SAME: %{{[a-z0-9]+}}: tensor {tf.aliasing_output = 1 : i64} +// CHECK-SAME: %arg1: tensor // CHECK-SAME: -> (tensor, tensor) -func @main(%arg0: tensor>>) -> tensor { +func @main(%arg0: tensor>>, %arg1: tensor) -> tensor { // CHECK-NEXT: %[[CONST_0:[a-z0-9]+]] = "tf.Const"() {value = dense<4.200000e+01> : tensor} %0 = "tf.Const"() {value = dense<4.200000e+01> : tensor} : () -> tensor "tf.AssignVariableOp"(%arg0, %0) : (tensor>>, tensor) -> () @@ -271,7 +276,7 @@ func @main(%arg0: tensor>>, %arg1: tensor>>) -> tensor { %0 = "tf.VarIsInitializedOp"(%arg0) : (tensor>>) -> tensor + %1 = "tf.UnknownOp"(%arg0) : (tensor>>) -> tensor return %0 : tensor } @@ -323,7 +329,7 @@ func @main(%arg0: tensor>>) -> tensor { // Tests VarHandleOp has users that are not removed. func @main() -> tensor { - // expected-error@+1 {{expects no uses but used by operations: tf.UnknownOp, tf.VarIsInitializedOp}} + // expected-error@+1 {{expects users to be 'tf.ReadVariableOp' or 'tf.AssignVariableOp', got [tf.UnknownOp, tf.VarIsInitializedOp]}} %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> %1 = "tf.VarIsInitializedOp"(%0) : (tensor>>) -> tensor %2 = "tf.UnknownOp"(%0) : (tensor>>) -> tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/promote_var_handles_to_args.mlir b/tensorflow/compiler/mlir/tensorflow/tests/promote_var_handles_to_args.mlir new file mode 100644 index 00000000000..8b8a070cfab --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/promote_var_handles_to_args.mlir @@ -0,0 +1,59 @@ +// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-promote-var-handles-to-args | FileCheck %s -dump-input-on-failure + +// Tests main function with multiple blocks. 
+ +// expected-error@+1 {{expects function 'main' to have 1 block, got 2}} +func @main() { + br ^bb1 +^bb1: + return +} + +// ----- + +// CHECK-LABEL: func @no_args +// CHECK-SAME: (%arg0: tensor {tf.resource_name = "x"}) +// CHECK-NOT: "tf.VarHandleOp" +func @no_args() { + %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor + return +} + +// CHECK-LABEL: func @some_args +// CHECK-SAME: (%arg0: tensor, %arg1: tensor {tf.resource_name = "x"}) +// CHECK-NOT: "tf.VarHandleOp" +func @some_args(%arg0: tensor) { + %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor + return +} + +// CHECK-LABEL: func @unique_vars +// CHECK-SAME: (%arg0: tensor>> {tf.resource_name = "x"}, %arg1: tensor>> {tf.resource_name = "y"}) +// CHECK-NOT: "tf.VarHandleOp" +func @unique_vars() { + %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> + %1 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "y"} : () -> tensor>> + return +} + +// CHECK-LABEL: func @duplicate_vars +// CHECK-SAME: (%arg0: tensor>> {tf.resource_name = "x"}) +// CHECK-NOT: "tf.VarHandleOp" +func @duplicate_vars() { + %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> + %1 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> + return +} + +// CHECK-LABEL: func @duplicate_vars_with_users +// CHECK-SAME: (%arg0: tensor, %arg1: tensor>> {tf.resource_name = "x"}) +// CHECK: "tf.ReadVariableOp"(%arg1) +// CHECK: "tf.AssignAddVariableOp"(%arg1, %arg0) +// CHECK-NOT: "tf.VarHandleOp" +func @duplicate_vars_with_users(%arg0: tensor) { + %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> + %1 = "tf.ReadVariableOp"(%0) : (tensor>>) -> tensor + %2 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> + "tf.AssignAddVariableOp"(%2, %arg0) : (tensor>>, tensor) -> () + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/readonly_references_to_resources.mlir b/tensorflow/compiler/mlir/tensorflow/tests/readonly_references_to_resources.mlir new file mode 100644 index 00000000000..2970e31c3c9 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/readonly_references_to_resources.mlir @@ -0,0 +1,85 @@ +// RUN: tf-opt -verify-diagnostics -readonly-references-to-resources -split-input-file %s | FileCheck %s --dump-input=fail + +// Test case: Basic converting. + +func @f() { + // CHECK: "tf.VarHandleOp" + // CHECK: "tf.ReadVariableOp" + %val0 = "tf.VariableV2"() {_class = ["loc:@v"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref> + %val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32> + return +} + +// ----- + +// Test case: Two ReadVariable ops. + +func @f() { + // CHECK: "tf.VarHandleOp" + + // During lowering to resource variables, this pass will preserve the + // locations of the ReadVariableOps as Identity ops to keep the original graph + // composition and order. 
+ + // CHECK: "tf.ReadVariableOp" + // CHECK: "tf.ReadVariableOp" + %val0 = "tf.VariableV2"() {_class = ["loc:@v"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref> + %val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32> + %val2 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32> + return +} + +// ----- + +// Test case: No follow-up ReadVariable case. + +func @f() { + // CHECK-NOT: "tf.VariableV2" + // CHECK-NOT: "tf.VarHandleOp" + %val0 = "tf.VariableV2"() {_class = ["loc:@v"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref> + return +} + +// ----- + +// Test case: No converting when there is another use case. + +func @f() { + // expected-error @+1 {{'tf.VariableV2' op expects all users to be 'tf.Identity', but got user tf.CustomIdentity}} + %val0 = "tf.VariableV2"() {_class = ["loc:@v"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref> + %val1 = "tf.CustomIdentity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32> + return +} + +// ----- + +// Test case: No class attribute on VariableV2 op. + +func @f() { + // expected-error @+1 {{'tf.VariableV2' op has no '_class' attribute}} + %val0 = "tf.VariableV2"() {container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref> + %val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32> + return +} + +// ----- + +// Test case: No named location found on VariableV2 op. + +func @f() { + // expected-error @+1 {{'tf.VariableV2' op expects variable name in '_class' attribute, but got ["unrelated_class"]}} + %val0 = "tf.VariableV2"() {_class = ["unrelated_class"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref> + %val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32> + return +} + +// ----- + +// Test case: Invalid multiple location information in a class attribute on VariableV2 op. + +func @f() { + // expected-error @+1 {{'tf.VariableV2' op expects only one named location in '_class' attribute, but got ["loc:@v1", "loc:@v2"]}} + %val0 = "tf.VariableV2"() {_class = ["loc:@v1", "loc:@v2"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref> + %val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32> + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir b/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir index c8b4ad2cb9f..28c542cded1 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir @@ -119,3 +119,60 @@ func @replicate_control() { // CHECK: %[[REPLICA_1:.*]] = tf_executor.island // CHECK: %[[SINK:.*]] = tf_executor.island(%[[REPLICA_0]], %[[REPLICA_1]]) // CHECK: tf_executor.fetch %[[SINK]] + + +// Tests replicate results are remapped correctly. 
+// CHECK-LABEL: func @replicate_result +func @replicate_result(%arg0: tensor, %arg1: tensor) { + %0:4 = tf_executor.graph { + %1:5 = tf_executor.island { + %2:4 = tf_device.replicate([%arg0, %arg1] as %arg2: tensor) {n = 2 : i32} { + %3 = "tf.opA"(%arg2) : (tensor) -> tensor + %4 = "tf.opB"(%arg2) : (tensor) -> tensor + tf_device.return %3, %4 : tensor, tensor + } + tf_executor.yield %2#0, %2#1, %2#2, %2#3 : tensor, tensor, tensor, tensor + } + tf_executor.fetch %1#0, %1#1, %1#2, %1#3 : tensor, tensor, tensor, tensor + } + return +} + +// CHECK: %[[REPLICA_0:.*]]:2, %{{.*}} = tf_executor.island +// CHECK: %[[REPLICA_1:.*]]:2, %{{.*}} = tf_executor.island +// CHECK: tf_executor.fetch %[[REPLICA_0]]#0, %[[REPLICA_1]]#0, %[[REPLICA_0]]#1, %[[REPLICA_1]]#1 + + +// Tests replica id is added correctly. +// CHECK-LABEL: func @replica_id_attr_added +func @replica_id_attr_added(%arg0: tensor, %arg1: tensor) { + tf_executor.graph { + tf_executor.island { + tf_device.replicate([%arg0, %arg1] as %arg2: tensor) {n = 2 : i32} { + "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg2){table_ids = [1, 2]} : (tensor) -> () + "tf.EnqueueTPUEmbeddingRaggedTensorBatch"(%arg2){table_ids = [1, 2]} : (tensor) -> () + "tf.A"(%arg2) : (tensor) -> () + tf_device.return + } + tf_executor.yield + } + tf_executor.fetch + } + return +} + +// CHECK: tf_executor.island +// CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch" +// CHECK-SAME: _xla_replica_id = 0 +// CHECK: "tf.EnqueueTPUEmbeddingRaggedTensorBatch" +// CHECK-SAME: _xla_replica_id = 0 +// CHECK: "tf.A" +// CHECK-NOT: _xla_replica_id +// CHECK: tf_executor.island +// CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch" +// CHECK-SAME: _xla_replica_id = 1 +// CHECK: "tf.EnqueueTPUEmbeddingRaggedTensorBatch" +// CHECK-SAME: _xla_replica_id = 1 +// CHECK: "tf.A" +// CHECK-NOT: _xla_replica_id +// CHECK: tf_executor.fetch diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource-device-inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource-device-inference.mlir index c98e40fed05..60eded3de7e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/resource-device-inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/resource-device-inference.mlir @@ -217,7 +217,7 @@ func @error_on_conflict_multiple_callers( // expected-error@above {{Conflicting device assignment for resource}} then_branch = @if_then_and_else, else_branch = @if_then_and_else, - output_shapes = [], is_stateless = false} + is_stateless = false} : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) -> () tf_executor.yield diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir index 9e7358ab2f5..2353dc5a7a8 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir @@ -406,6 +406,61 @@ func @while_cond(%arg0: tensor<*x!tf.resource>>) -> tensor { // ----- +// CHECK: func @cluster_with_case(%[[ARG0:.*]]: tensor) -> tensor<4xf32> +func @cluster_with_case(%arg0: tensor) -> tensor<4xf32> { + // CHECK: %[[VH0:.*]] = "tf.VarHandleOp"() + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> + // CHECK: %[[VH1:.*]] = "tf.VarHandleOp"() + %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource>> + // CHECK-DAG: %[[READ0:.*]] = "tf.ReadVariableOp"(%[[VH0]]) + // CHECK-DAG: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[VH1]]) + // CHECK: 
%[[CLUSTER:.*]]:2 = "tf_device.cluster"() + %2 = "tf_device.cluster"() ( { + // CHECK: %[[CASE:.*]]:2 = "tf.Case"(%[[ARG0]], %[[READ0]], %[[READ1]]) + %3:2 = "tf.Case"(%arg0, %0, %1) {branches = [@branch_0, @branch_1, @branch_2]} + : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) + -> (tensor<*x!tf.resource>>, tensor<4xf32>) + // CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[CASE]]#1, %[[CASE]]#0) + %4 = "tf.ReadVariableOp"(%3#0) : (tensor<*x!tf.resource>>) -> tensor<4xf32> + %5 = "tf.AddV2"(%4, %3#1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: tf_device.return %[[ADD]], %[[CASE]]#1 + tf_device.return %5 : tensor<4xf32> + // CHECK: {cluster_attr = "cluster_attr"} : () -> (tensor<4xf32>, tensor<4xf32>) + }) {cluster_attr = "cluster_attr"} : () -> tensor<4xf32> + // CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[CLUSTER]]#1) + // CHECK: return %[[CLUSTER]]#0 + return %2 : tensor<4xf32> +} +// CHECK: func @branch_0(%[[TARG0:.*]]: tensor<4xf32>, %[[TARG1:.*]]: tensor<4xf32>) +func @branch_0(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf.resource>>) + -> (tensor<*x!tf.resource>>, tensor<4xf32>) { + // CHECK-NEXT: %[[CONST:.*]] = "tf.Const"() + %constant = "tf.Const"() {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> + "tf.AssignVariableOp"(%arg0, %constant) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + // CHECK-NEXT: return %[[CONST]], %[[CONST]] + return %arg0, %constant : tensor<*x!tf.resource>>, tensor<4xf32> +} +// CHECK: func @branch_1(%[[EARG0:.*]]: tensor<4xf32>, %[[EARG1:.*]]: tensor<4xf32>) +func @branch_1(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf.resource>>) + -> (tensor<*x!tf.resource>>, tensor<4xf32>) { + %id = "tf.Identity"(%arg1) : (tensor<*x!tf.resource>>) -> tensor<*x!tf.resource>> + %read = "tf.ReadVariableOp"(%id) : (tensor<*x!tf.resource>>) -> tensor<4xf32> + "tf.AssignVariableOp"(%arg0, %read) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + // CHECK-NEXT: return %[[EARG1]], %[[EARG1]] + return %arg0, %read : tensor<*x!tf.resource>>, tensor<4xf32> +} +// CHECK: func @branch_2(%[[EARG0:.*]]: tensor<4xf32>, %[[EARG1:.*]]: tensor<4xf32>) +func @branch_2(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf.resource>>) + -> (tensor<*x!tf.resource>>, tensor<4xf32>) { + %id = "tf.Identity"(%arg1) : (tensor<*x!tf.resource>>) -> tensor<*x!tf.resource>> + %read = "tf.ReadVariableOp"(%id) : (tensor<*x!tf.resource>>) -> tensor<4xf32> + "tf.AssignVariableOp"(%arg0, %read) : (tensor<*x!tf.resource>>, tensor<4xf32>) -> () + // CHECK-NEXT: return %[[EARG1]], %[[EARG1]] + return %arg0, %read : tensor<*x!tf.resource>>, tensor<4xf32> +} + +// ----- + // Tests that pass lifts resource reads from if branches. 
// CHECK: func @cluster_with_if(%[[ARG0:.*]]: tensor) -> tensor<4xf32> @@ -420,7 +475,7 @@ func @cluster_with_if(%arg0: tensor) -> tensor<4xf32> { %2 = "tf_device.cluster"() ( { // CHECK: %[[IF:.*]]:2 = "tf.If"(%[[ARG0]], %[[READ0]], %[[READ1]]) %3:2 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else, - output_shapes = [#tf.shape<>, #tf.shape<4>], is_stateless = false} + is_stateless = false} : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>, tensor<4xf32>) // CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[IF]]#1, %[[IF]]#0) @@ -468,7 +523,7 @@ func @cluster_with_nested_if(%arg0: tensor) -> tensor { %2 = "tf_device.cluster"() ( { // CHECK: %[[IF:.*]] = "tf.If"(%[[ARG0]], %[[READ0]]) %3 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else, - output_shapes = [], is_stateless = false} + is_stateless = false} : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) // CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[IF]], %[[IF]]) @@ -488,7 +543,7 @@ func @if_then(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf.re // CHECK-NEXT: %[[IIF:.*]] = "tf.If"(%[[TARG0]], %[[TARG0]]) %read = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>>) -> tensor %3 = "tf.If"(%read, %arg0) {then_branch = @inner_if_then, else_branch = @inner_if_else, - output_shapes = [], is_stateless = false} + is_stateless = false} : (tensor, tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) // CHECK-NEXT: return %[[IIF]] @@ -524,9 +579,9 @@ func @cluster_with_if(%arg0: tensor) -> tensor<4xf32> { %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource>> %2 = "tf_device.cluster"() ( { - // expected-error @+1 {{unsupported tf.IfOp output: resource does not alias a single input.}} + // expected-error @+1 {{unsupported output: resource does not alias a single input}} %3 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else, - output_shapes = [#tf.shape<>], is_stateless = false} + is_stateless = false} : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) %4 = "tf.ReadVariableOp"(%3) : (tensor<*x!tf.resource>>) -> tensor<4xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir index 160bba94cfc..e3766a7d9d6 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir @@ -1,10 +1,11 @@ -// RUN: tf-opt %s -tf-shape-inference -verify-diagnostics | FileCheck %s -dump-input=fail +// RUN: tf-opt %s -tf-shape-inference=propagate-caller-callee-constants=false -verify-diagnostics | FileCheck %s -dump-input=fail +// RUN: tf-opt %s -tf-shape-inference=propagate-caller-callee-constants -verify-diagnostics | FileCheck %s -dump-input=fail module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 130 : i32}} { // CHECK-LABEL: func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<*xi32> { - // CHECK-NOT: tf.Cast - // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + // CHECK: %[[RESULT:.*]] = "tf.AddV2" + // CHECK-SAME: (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: return %[[RESULT]] : tensor<1xi32> %0 = 
"tf.Cast"(%arg0) : (tensor<1xi32>) -> tensor<*xi32> %1 = "tf.Cast"(%arg1) : (tensor<1xi32>) -> tensor<*xi32> @@ -60,8 +61,8 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { // CHECK-LABEL: func @simple_folding func @simple_folding(%arg0: tensor<1x1x1x1xi32>, %arg1: tensor<1x1x1x1xf32>) -> tensor { -// CHECK: %[[CST:.*]] = "tf.Const"{{.*}} {value = dense<1> : tensor<4xi32>} : () -> tensor<4xi32> -// CHECK: %[[CONV:.*]] = "tf.Conv2DBackpropInput"(%[[CST]] +// CHECK: %[[SHAPE:.*]] = "tf.Shape" +// CHECK: %[[CONV:.*]] = "tf.Conv2DBackpropInput"(%[[SHAPE]] // CHECK-SAME: (tensor<4xi32>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> // CHECK: return %[[CONV]] : tensor<1x1x1x1xf32> %0 = "tf.Shape"(%arg0) : (tensor<1x1x1x1xi32>) -> tensor<4xi32> @@ -101,7 +102,7 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { // CHECK-LABEL: func @shape_from_if_to_branch_functions func @shape_from_if_to_branch_functions(%arg0: tensor, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> { - %0 = "tf.If"(%arg0, %arg1) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @if_else_branch, is_stateless = true, name = "if", output_shapes = [#tf.shape<>], then_branch = @if_then_branch} : (tensor, tensor<1x2x3xf32>) -> tensor<1x2x3xf32> + %0 = "tf.If"(%arg0, %arg1) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @if_else_branch, is_stateless = true, name = "if", then_branch = @if_then_branch} : (tensor, tensor<1x2x3xf32>) -> tensor<1x2x3xf32> return %0 : tensor<1x2x3xf32> } @@ -183,16 +184,16 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { // CHECK-LABEL: func @invalid_function_reused_by_control_flows func @invalid_function_reused_by_control_flows(%arg0: tensor, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> { - // expected-warning @+1 {{unable to refine shape}} - %0 = "tf.If"(%arg0, %arg1) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @reused_if_else_branch, is_stateless = true, name = "if", output_shapes = [#tf.shape<>], then_branch = @reused_if_then_branch} : (tensor, tensor<1x2x3xf32>) -> tensor<1x2x3xf32> - // expected-warning @+1 {{unable to refine shape}} - %1 = "tf.If"(%arg0, %0) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @reused_if_else_branch, is_stateless = true, name = "if", output_shapes = [#tf.shape<>], then_branch = @reused_if_then_branch} : (tensor, tensor<1x2x3xf32>) -> tensor<1x2x3xf32> + // expected-warning @+1 {{unable to refine shape}} + %0 = "tf.If"(%arg0, %arg1) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @reused_if_else_branch, is_stateless = true, name = "if", then_branch = @reused_if_then_branch} : (tensor, tensor<1x2x3xf32>) -> tensor<1x2x3xf32> + // expected-warning @+1 {{unable to refine shape}} + %1 = "tf.If"(%arg0, %0) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @reused_if_else_branch, is_stateless = true, name = "if", then_branch = @reused_if_then_branch} : (tensor, tensor<1x2x3xf32>) -> tensor<1x2x3xf32> return %0 : tensor<1x2x3xf32> } // CHECK-LABEL: func 
@reused_if_then_branch // CHECK-SAME: (%arg0: tensor<*xf32>) -> tensor<*xf32> - // expected-warning @+1 {{expected control flow function reused_if_then_branch to have exactly 1 use}} + // expected-warning @+1 {{expected control flow function reused_if_then_branch to have exactly 1 use}} func @reused_if_then_branch(%arg0: tensor<*xf32>) -> tensor<*xf32> { // CHECK: return // CHECK-SAME: tensor<*xf32> @@ -201,7 +202,7 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { // CHECK-LABEL: func @reused_if_else_branch // CHECK-SAME: (%arg0: tensor<*xf32>) -> tensor<*xf32> - // expected-warning @+1 {{expected control flow function reused_if_else_branch to have exactly 1 use}} + // expected-warning @+1 {{expected control flow function reused_if_else_branch to have exactly 1 use}} func @reused_if_else_branch(%arg0: tensor<*xf32>) -> tensor<*xf32> { // CHECK: "tf.Identity"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Identity"(%arg0) : (tensor<*xf32>) -> (tensor<*xf32>) @@ -300,13 +301,6 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { return %0 : tensor<*xi32> } - // CHECK-LABEL: func @fold_cast - func @fold_cast(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK-NOT: Cast - %0 = "tf.Cast"(%arg0) : (tensor<*xf32>) -> (tensor<*xf32>) - return %0 : tensor<*xf32> - } - // CHECK-LABEL: func @while_variant // CHECK-SAME: -> tensor>> func @while_variant(%arg0: tensor>>) -> tensor { @@ -362,8 +356,6 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { // CHECK-LABEL: func @partitioned_call_func_const func @partitioned_call_func_const(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<2xi32> - // CHECK: return %[[CONST]] return %arg0 : tensor<2xi32> } @@ -410,4 +402,18 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { %40 = "tf.Reshape"(%39, %19) {T = f32, Tshape = i32, device = ""} : (tensor<1x4x4x32xf32>, tensor<2xi32>) -> tensor return } + + // CHECK-LABEL: const_fold + func @const_fold() -> () { + // CHECK: tf.Const + // CHECK-SAME: () -> tensor<4xi32> + %0 = "tf.Const"() {value = dense<[200, 26, 26, 32]> : tensor<4xi32>} : () -> tensor<*xi32> + // CHECK: tf.Const + // CHECK-SAME: () -> tensor<4xi32> + %1 = "tf.Const"() {value = dense<[200, 26, 26, 32]> : tensor<4xi32>} : () -> tensor<*xi32> + // CHECK: tf.Add + // CHECK-SAME: (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %2 = "tf.Add"(%0, %1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + return + } } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir b/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir index c36771c0576..965b3b10843 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir @@ -786,9 +786,9 @@ func @tf_registry_ops( // CHECK-LABEL: func @arguments_with_unique_ids func @arguments_with_unique_ids( // expected-remark@above {{ID: 9}} - %arg0: tensor<*x!tf.resource>> {tf.resource_arg_unique_id = 0 : i64}, - %arg1: tensor<*x!tf.resource>> {tf.resource_arg_unique_id = 0 : i64}, - %arg2: tensor<*x!tf.resource>> {tf.resource_arg_unique_id = 33 : i64}) { + %arg0: tensor<*x!tf.resource>> {tf._resource_arg_unique_id = 0 : i64}, + %arg1: tensor<*x!tf.resource>> {tf._resource_arg_unique_id = 0 : i64}, + %arg2: tensor<*x!tf.resource>> {tf._resource_arg_unique_id = 33 : i64}) { tf_executor.graph { // 
expected-remark@above {{ID: 7}} // expected-remark@above {{Successors: {8}}} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index 118ce2e8645..8692104e772 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -854,6 +854,265 @@ func @testInvalidIfOp(tensor, tensor<*xf32>) -> tensor<2xf32> { // ----- +// Test invalid tf.Yield operation (parent should be IfRegion) +func @testInvalidYieldOp(%arg0: f32) -> () { + // expected-error @+1 {{expects parent op 'tf.IfRegion'}} + "tf.Yield"(%arg0) : (f32) -> () +} + +// ----- + +// Test valid tf.IfRegion operation +// CHECK-LABEL: func @testValidIfRegionOp +func @testValidIfRegionOp(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { + %neg = "tf.Neg"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + %0 = "tf.IfRegion"(%arg0) ({ + %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%t) : (tensor<2xf32>) -> () + }, { + %e = "tf.Acos"(%neg) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%e) : (tensor<2xf32>) -> () + }) { is_stateless = false} : (tensor) -> tensor<2xf32> + + return %0 : tensor<2xf32> +} + +// ----- + +// Test valid tf.IfRegion operation with multiple results +// CHECK-LABEL: func @testValidIfRegionOpWithMultipleResults +func @testValidIfRegionOpWithMultipleResults(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { + %0, %1, %2 = "tf.IfRegion"(%arg0) ({ + %t0 = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + %t1 = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + %t2 = "tf.Acosh"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%t0, %t1, %t2) : (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> () + }, { + %e0 = "tf.Neg"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + %e1 = "tf.Relu"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + %e2 = "tf.Sin"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%e0, %e1, %e2) : (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> () + }) { is_stateless = false} : (tensor) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) + + %3 = "tf.Add"(%0, %1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + %4 = "tf.Add"(%2, %3) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + return %4 : tensor<2xf32> +} + +// ----- + +// Test invalid type for operand #0 for tf.IfRegion operation +func @testInvalidIfRegionOpType0(%arg0: f32, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // expected-error @+1 {{operand #0 must be tensor of tf.dtype values}} + %0 = "tf.IfRegion"(%arg0) ({ + %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%t) : (tensor<2xf32>) -> () + }, { + %e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%e) : (tensor<2xf32>) -> () + }) { is_stateless = false} : (f32) -> tensor<2xf32> + + return %0 : tensor<2xf32> +} + +// ----- + +// tf.IfRegion operation should have 2 regions +func @testInvalidIfRegionOp1Region(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // expected-error @+1 {{op expected 2 regions}} + %0 = "tf.IfRegion"(%arg0) ({ + %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%t) : (tensor<2xf32>) -> () + }) { is_stateless = false} : (tensor) -> tensor<2xf32> + + return %0 : tensor<2xf32> +} + +// ----- + +func @testInvalidIfRegionOpNoRegions(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // expected-error @+1 {{op expected 2 regions}} + %0 = "tf.IfRegion"(%arg0) { is_stateless = false} : (tensor) -> tensor<2xf32> + + return %0 : tensor<2xf32> +} 
+ +// ----- + +func @testInvalidIfRegionOp3Regions(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // expected-error @+1 {{op expected 2 regions}} + %0 = "tf.IfRegion"(%arg0) ({ + %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%t) : (tensor<2xf32>) -> () + }, { + %te = "tf.Relu"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%te) : (tensor<2xf32>) -> () + }, { + %e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%e) : (tensor<2xf32>) -> () + }) { is_stateless = false} : (tensor) -> tensor<2xf32> + + return %0 : tensor<2xf32> +} + +// ----- + +// tf.IfRegion regions should be terminated with a tf.Yield +func @testIfRegionThenTerminator(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // expected-error @+2 {{'tf.IfRegion' op expects regions to end with 'tf.Yield'}} + // expected-note @+1 {{in custom textual format, the absence of terminator implies 'tf.Yield'}} + %0 = "tf.IfRegion"(%arg0) ({ + %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + }, { + %e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%e) : (tensor<2xf32>) -> () + }) { is_stateless = false} : (tensor) -> tensor<2xf32> + + return %0 : tensor<2xf32> +} + +// ----- + +func @testIfRegionElseTerminator(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // expected-error @+2 {{'tf.IfRegion' op expects regions to end with 'tf.Yield'}} + // expected-note @+1 {{in custom textual format, the absence of terminator implies 'tf.Yield'}} + %0 = "tf.IfRegion"(%arg0) ({ + %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%t) : (tensor<2xf32>) -> () + }, { + %e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + }) { is_stateless = false} : (tensor) -> tensor<2xf32> + + return %0 : tensor<2xf32> +} + +// ----- + +// tf.Region yield number of results should match op number of results +func @testIfRegionThenResultCount(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // expected-error @+1 {{then region should have 1 result}} + %0 = "tf.IfRegion"(%arg0) ({ + %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%t, %t) : (tensor<2xf32>, tensor<2xf32>) -> () + }, { + %e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%e) : (tensor<2xf32>) -> () + }) { is_stateless = false} : (tensor) -> tensor<2xf32> + + return %0 : tensor<2xf32> +} + +// ----- + +func @testIfRegionElseResultCount(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // expected-error @+1 {{else region should have 1 result}} + %0 = "tf.IfRegion"(%arg0) ({ + %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%t) : (tensor<2xf32>) -> () + }, { + %e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%e, %e) : (tensor<2xf32>, tensor<2xf32>) -> () + }) { is_stateless = false} : (tensor) -> tensor<2xf32> + + return %0 : tensor<2xf32> +} + +// ----- + +// tf.IfRegion yield types should match op result types +func @testIfRegionOpYieldMismatchThen(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // expected-error @+1 {{then result type tensor is incompatible with tf.IfRegion result type tensor<2xf32> at index 0}} + %0 = "tf.IfRegion"(%arg0) ({ + "tf.Yield"(%arg0) : (tensor) -> () + }, { + %e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%e) : (tensor<2xf32>) -> () + }) { is_stateless = false} : (tensor) -> tensor<2xf32> + + return %0 : tensor<2xf32> +} + +// ----- + +func @testIfRegionOpYieldMismatchElse(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> 
{ + // expected-error @+1 {{else result type tensor is incompatible with tf.IfRegion result type tensor<2xf32> at index 0}} + %0 = "tf.IfRegion"(%arg0) ({ + %t = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%t) : (tensor<2xf32>) -> () + }, { + "tf.Yield"(%arg0) : (tensor) -> () + }) { is_stateless = false} : (tensor) -> tensor<2xf32> + + return %0 : tensor<2xf32> +} + +// ----- + +// value generated in one branch cannot be consumed in the other branch +func @testIfRegionElseConsumingThen(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { + %0 = "tf.IfRegion"(%arg0) ({ + %t = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%t) : (tensor<2xf32>) -> () + }, { + // expected-error @+1 {{use of undeclared SSA value name}} + "tf.Yield"(%t) : (tensor<2xf32>) -> () + }) { is_stateless = false} : (tensor) -> tensor<2xf32> + + return %0 : tensor<2xf32> +} + +// ----- + +func @testIfRegionThenConsumingElse(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { + %0 = "tf.IfRegion"(%arg0) ({ + // expected-error @+1 {{does not dominate this use}} + "tf.Yield"(%t) : (tensor<2xf32>) -> () + }, { + // expected-note @+1 {{operand defined here}} + %t = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%t) : (tensor<2xf32>) -> () + }) { is_stateless = false} : (tensor) -> tensor<2xf32> + + return %0 : tensor<2xf32> +} + +// ----- + +// The regions for IfRegion themselves cannot have any arguments +func @testInvalidIfRegionThenArg(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { + %neg = "tf.Neg"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + // expected-error @+1 {{then region cannot have any arguments}} + %0 = "tf.IfRegion"(%arg0) ({ + ^bb(%arg_bb: tensor<2xf32>): + %t = "tf.Abs"(%arg_bb) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%t) : (tensor<2xf32>) -> () + }, { + %e = "tf.Acos"(%neg) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%e) : (tensor<2xf32>) -> () + }) { is_stateless = false} : (tensor) -> tensor<2xf32> + + return %0 : tensor<2xf32> +} + +// ----- + +func @testInvalidIfRegionElseArg(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { + %neg = "tf.Neg"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + // expected-error @+1 {{else region cannot have any arguments}} + %0 = "tf.IfRegion"(%arg0) ({ + %t = "tf.Abs"(%neg) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%t) : (tensor<2xf32>) -> () + }, { + ^bb(%arg_bb: tensor<2xf32>): + %e = "tf.Acos"(%arg_bb) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%e) : (tensor<2xf32>) -> () + }) { is_stateless = false} : (tensor) -> tensor<2xf32> + + return %0 : tensor<2xf32> +} + +// ----- + // Test valid tf.MatrixBandPart // CHECK-LABEL: func @testValidMatrixBandPartOp func @testValidMatrixBandPartOp(%arg0: tensor<64x64xbf16>, %arg1: tensor, %arg2: tensor) -> tensor<64x64xbf16> { @@ -881,20 +1140,29 @@ func @testValidMatrixBandPartOpUnranked(%arg0: tensor<*xbf16>, %arg1: tensor, %arg1: tensor, %arg2: tensor) -> tensor<64x64xbf16> { - // expected-error @+1 {{op failed to verify that all of {input, band} have same type}} - %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor, tensor) -> tensor<64x64xbf16> - return %0 : tensor<64x64xbf16> +// Test valid tf.MatrixBandPart +// CHECK-LABEL: func @testValidMatrixBandPartOpUnrankedBand +func @testValidMatrixBandPartOpUnrankedBand(%arg0: tensor<64x64x64xbf16>, %arg1: tensor, %arg2: tensor) -> tensor<*xbf16> { + %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor, tensor) -> tensor<*xbf16> + return 
%0 : tensor<*xbf16> +} + +// ----- + +// Test valid tf.MatrixBandPart +// CHECK-LABEL: func @testValidMatrixBandPartOpCompatibleDynamicShapes +func @testValidMatrixBandPartOpCompatibleDynamicShapes(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + return %0 : tensor } // ----- // Test invalid tf.MatrixBandPart -func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64x64xbf16>, %arg1: tensor, %arg2: tensor) -> tensor<64x64xbf16> { - // expected-error @+1 {{op failed to verify that all of {input, band} have same type}} - %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor, tensor) -> tensor<64x64xbf16> - return %0 : tensor<64x64xbf16> +func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64x64xbf16>, %arg1: tensor, %arg2: tensor) -> tensor<64x64xbf16> { + // expected-error @+1 {{op failed to verify that all of {input, band} have dynamically equal types}} + %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor, tensor) -> tensor<64x64xbf16> + return %0 : tensor<64x64xbf16> } // ----- @@ -998,6 +1266,116 @@ func @pcall_func_2(%arg0: tensor, %arg1: tensor) -> tensor { // ----- +//===--------------------------------------------------------------------===// +// tf.Select +//===--------------------------------------------------------------------===// + +// Test valid tf.Select +// CHECK-LABEL: func @testSelect +func @testSelect(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2: tensor<3x2xf16>) -> tensor<3x2xf16> { + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<3x2xf16>, tensor<3x2xf16>) -> tensor<3x2xf16> + return %0: tensor<3x2xf16> +} + +// ----- + +func @testInvalidSelect(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> { + // expected-error @+1 {{requires that, when pred is a vector, the shape matches the first dimension of t and e}} + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16> + return %0: tensor<2x3xf16> +} + +// ----- + +// Test invalid tf.Select - broadcasting then/else parameters is not supported +func @selectBroadcastThen(%arg0: tensor, %arg1: tensor<8x1xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> { + // expected-error @+1 {{requires t and e have compatible shapes}} + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor, tensor<8x1xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32> + return %0: tensor<2x8x8xi32> +} + +// ----- + +func @invalidSelect(%arg0: tensor<2xi1>, %arg1: tensor, %arg2: tensor) -> tensor<2xi32> { + // expected-error @+1 {{requires that t and e are nonscalar when pred is a vector}} + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor, tensor) -> tensor<2xi32> + return %0: tensor<2xi32> +} + +// ----- + +func @invalidSelect(%arg0: tensor<1x8xi1>, %arg1: tensor<1x8x8xi32>, %arg2: tensor<1x8x8xi32>) -> tensor<1x8x8xi32> { + // expected-error @+1 {{requires that pred is a scalar OR has the same rank as t and e OR is a vector}} + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<1x8xi1>, tensor<1x8x8xi32>, tensor<1x8x8xi32>) -> tensor<1x8x8xi32> + return %0: tensor<1x8x8xi32> +} + +// ----- + +//===--------------------------------------------------------------------===// +// tf.SelectV2 +//===--------------------------------------------------------------------===// + +// Test valid tf.SelectV2 +// CHECK-LABEL: func @selectV2BroadcastThen +func @selectV2BroadcastThen(%arg0: tensor, %arg1: tensor<8x1xi32>, %arg2:
tensor<2x8x8xi32>) -> tensor<2x8x8xi32> { + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor, tensor<8x1xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32> + return %0: tensor<2x8x8xi32> +} + +// ----- + +// Test valid tf.SelectV2 +// CHECK-LABEL: func @selectV2BroadcastElse +func @selectV2BroadcastElse(%arg0: tensor, %arg1: tensor<2x8x8xi32>, %arg2: tensor<8x1xi32>) -> tensor<2x8x8xi32> { + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor, tensor<2x8x8xi32>, tensor<8x1xi32>) -> tensor<2x8x8xi32> + return %0: tensor<2x8x8xi32> +} + +// ----- + +// Test valid tf.SelectV2 +// CHECK-LABEL: func @selectV2BroadcastPred +func @selectV2BroadcastPred(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> { + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32> + return %0: tensor<2x8x8xi32> +} + +// ----- + +// CHECK-LABEL: func @selectV2BroadcastAll +func @selectV2BroadcastAll(%arg0: tensor<8x1x1xi1>, %arg1: tensor<1x8x1xi32>, %arg2: tensor<1x1x8xi32>) -> tensor<8x8x8xi32> { + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<8x1x1xi1>, tensor<1x8x1xi32>, tensor<1x1x8xi32>) -> tensor<8x8x8xi32> + return %0: tensor<8x8x8xi32> +} + +// ----- + +// CHECK-LABEL: func @selectV2DynamicRanked +func @selectV2DynamicRanked(%arg0: tensor<1xi1>, %arg1: tensor<2x?x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x?x8xi32> { + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x?x8xi32>, tensor<2x8x8xi32>) -> tensor<2x?x8xi32> + return %0: tensor<2x?x8xi32> +} + +// ----- + +// CHECK-LABEL: func @selectV2Unranked +func @selectV2Unranked(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<*xi32>) -> tensor<*xi32> { + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<*xi32>) -> tensor<*xi32> + return %0: tensor<*xi32> +} + +// ----- + +// Test invalid tf.SelectV2: this is an invalid broadcast for the predicate +func @testInvalidSelectV2(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2: tensor<3x2xf16>) -> tensor<3x2xf16> { + // expected-error @+1 {{operands don't have broadcast-compatible shapes}} + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<3x2xf16>, tensor<3x2xf16>) -> tensor<3x2xf16> + return %0: tensor<3x2xf16> +} + +// ----- + //===--------------------------------------------------------------------===// // tf.Softmax //===--------------------------------------------------------------------===// @@ -1317,7 +1695,7 @@ func @testShapeMismatchDim(tensor<1x32x32x16xf32>) -> tensor<2xi32> { func @testShapeWrongResultDimDynamic(tensor<*xf32>) -> tensor<2xi32> { ^bb0(%arg0: tensor<*xf32>): - // expected-error @+1 {{requires dynamic shape result for unranked operand}} + // expected-warning @+1 {{has static shape result for unranked operand}} %0 = "tf.Shape"(%arg0) {T = "tfdtype$DT_FLOAT", output = "tfdtype$DT_INT32"} : (tensor<*xf32>) -> tensor<2xi32> return %0 : tensor<2xi32> } @@ -1361,7 +1739,7 @@ func @testShapeNMismatchDim(tensor<1x32x32x16xf32>) -> tensor<2xi32> { func @testShapeNWrongResultDimDynamic(tensor<*xf32>) -> tensor<2xi32> { ^bb0(%arg0: tensor<*xf32>): - // expected-error @+1 {{requires dynamic shape result #1 for unranked operand #1}} + // expected-warning @+1 {{has static shape result #1 for unranked operand #1}} %0:2 = "tf.ShapeN"(%arg0, %arg0) : (tensor<*xf32>, tensor<*xf32>) -> (tensor, tensor<2xi32>) return %0#1 : tensor<2xi32> } @@ -1419,7 +1797,7 @@ func @testVariableShapeMismatchDim(%arg0: tensor<*x!tf.resource>>) -> 
tensor<2xi32> { - // expected-error @+1 {{requires dynamic shape result for unranked operand}} + // expected-warning @+1 {{has static shape result for unranked operand}} %0 = "tf.VariableShape"(%arg0) {output = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>>) -> tensor<2xi32> return %0 : tensor<2xi32> } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic.py index 78c18a17d4a..b337224e680 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic.py +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic.py @@ -45,7 +45,7 @@ class TestModule(tf.Module): # CHECK: "tf_saved_model.global_tensor"() {is_mutable, sym_name = "[[VAR:[a-zA-Z_0-9]+]]", tf_saved_model.exported_names = ["v42"], type = tensor, value = dense<4.200000e+01> : tensor} : () -> () # CHECK: "tf_saved_model.global_tensor"() {sym_name = "[[CONST:[a-zA-Z_0-9]+]]", tf_saved_model.exported_names = [], type = tensor, value = dense<4.300000e+01> : tensor} : () -> () # CHECK: func {{@[a-zA-Z_0-9]+}}( - # CHECK-SAME: %arg0: tensor {tf_saved_model.index_path = [0]}, + # CHECK-SAME: %arg0: tensor {tf._user_specified_name = "x", tf_saved_model.index_path = [0]}, # CHECK-SAME: %arg1: tensor>> {tf_saved_model.bound_input = @[[VAR]]}, # CHECK-SAME: %arg2: tensor>> {tf_saved_model.bound_input = @[[CONST]]}) -> ( # CHECK-SAME: tensor {tf_saved_model.index_path = []}) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/call_to_exported.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/call_to_exported.py index 658cc37a22f..694942f4b00 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/call_to_exported.py +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/call_to_exported.py @@ -45,7 +45,7 @@ class TestModule(tf.Module): # modify signatures interprocedurally). # # CHECK: func {{@[a-zA-Z_0-9]+}}( - # CHECK-SAME: %arg0: tensor {tf_saved_model.index_path = [0]}, + # CHECK-SAME: %arg0: tensor {tf._user_specified_name = "x", tf_saved_model.index_path = [0]}, # CHECK-SAME: %arg1: tensor> {tf_saved_model.bound_input = {{@[a-zA-Z_0-9]+}}} # CHECK-SAME: ) -> ( # CHECK-SAME: tensor {tf_saved_model.index_path = [0]}, @@ -54,7 +54,7 @@ class TestModule(tf.Module): # CHECK: "tf.StatefulPartitionedCall"{{.*}}f = @[[CALLEE_INTERNAL:[a-zA-Z_0-9]+]] # # CHECK: func {{@[a-zA-Z_0-9]+}}( - # CHECK-SAME: %arg0: tensor {tf_saved_model.index_path = [0]}, + # CHECK-SAME: %arg0: tensor {tf._user_specified_name = "x", tf_saved_model.index_path = [0]}, # CHECK-SAME: %arg1: tensor> {tf_saved_model.bound_input = {{@[a-zA-Z_0-9]+}}} # CHECK-SAME: ) -> ( # CHECK-SAME: tensor {tf_saved_model.index_path = [0]}, diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/defun_export.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/defun_export.py new file mode 100644 index 00000000000..8bd128898a0 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/defun_export.py @@ -0,0 +1,63 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# RUN: %p/defun_export | FileCheck %s + +# pylint: disable=missing-docstring,line-too-long +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.compat.v1 as tf +from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1 +from tensorflow.python.framework import function + + +@function.Defun(tf.float32, tf.float32) +def plus(a, b): + return a + b + + +def test_defun(): + x = tf.constant([[1.0], [1.0], [1.0]]) + y = tf.constant([[2.0], [2.0], [2.0]]) + + # Verify that the function defined using function.Defun + # has a corresponding tf.LegacyCall op. + # CHECK: func {{@[a-zA-Z_0-9]+}}( + # CHECK-SAME: [[ARG0:%.*]]: tensor<3x1xf32> {tf_saved_model.index_path = ["y"]}, + # CHECK-SAME: [[ARG1:%.*]]: tensor<3x1xf32> {tf_saved_model.index_path = ["x"]} + # + # CHECK-NEXT: [[R0:%.*]] = "tf.LegacyCall"([[ARG1]], [[ARG0]]) + z = plus(x, y) + + tensor_info_x = tf.compat.v1.saved_model.utils.build_tensor_info(x) + tensor_info_y = tf.compat.v1.saved_model.utils.build_tensor_info(y) + tensor_info_z = tf.compat.v1.saved_model.utils.build_tensor_info(z) + + return { + 'key': (tf.compat.v1.saved_model.signature_def_utils.build_signature_def( + inputs={ + 'x': tensor_info_x, + 'y': tensor_info_y + }, + outputs={'z': tensor_info_z}, + method_name='test_function')) + } + + +if __name__ == '__main__': + common_v1.set_tf_options() + common_v1.do_test(test_defun()) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/keras.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/keras.py index a95909b61ef..ffb5c024bbb 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/keras.py +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/keras.py @@ -39,7 +39,7 @@ class TestModule(tf.Module): super(TestModule, self).__init__() self.model = mnist_model() - # CHECK: func {{@[a-zA-Z_0-9]+}}(%arg0: tensor<1x28x28x1xf32> {tf_saved_model.index_path = [0]} + # CHECK: func {{@[a-zA-Z_0-9]+}}(%arg0: tensor<1x28x28x1xf32> {tf._user_specified_name = "x", tf_saved_model.index_path = [0]} # CHECK: attributes {{.*}} tf_saved_model.exported_names = ["my_predict"] @tf.function(input_signature=[ tf.TensorSpec([1, 28, 28, 1], tf.float32), diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/structured_input.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/structured_input.py index 095fddbda96..43591d12183 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/structured_input.py +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/structured_input.py @@ -36,8 +36,8 @@ class TestModule(tf.Module): # The outer layer of the index path indexes into the arguments. 
# # CHECK: func {{@[a-zA-Z_0-9]+}}( - # CHECK-SAME: %arg0: tensor<1xf32> {tf_saved_model.index_path = [0]}, - # CHECK-SAME: %arg1: tensor<2xf32> {tf_saved_model.index_path = [1]}) + # CHECK-SAME: %arg0: tensor<1xf32> {tf._user_specified_name = "x", tf_saved_model.index_path = [0]}, + # CHECK-SAME: %arg1: tensor<2xf32> {tf._user_specified_name = "y", tf_saved_model.index_path = [1]}) # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0000_function_arity"] @tf.function(input_signature=[ tf.TensorSpec([1], tf.float32), @@ -49,8 +49,8 @@ class TestModule(tf.Module): # Check index paths for lists. # # CHECK: func {{@[a-zA-Z_0-9]+}}( - # CHECK-SAME: %arg0: tensor {tf_saved_model.index_path = [0, 0]}, - # CHECK-SAME: %arg1: tensor {tf_saved_model.index_path = [0, 1]}) + # CHECK-SAME: %arg0: tensor {tf._user_specified_name = "l", tf_saved_model.index_path = [0, 0]}, + # CHECK-SAME: %arg1: tensor {tf._user_specified_name = "l", tf_saved_model.index_path = [0, 1]}) # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0001_list_2_elements"] @tf.function(input_signature=[[ tf.TensorSpec([], tf.float32), @@ -63,8 +63,8 @@ class TestModule(tf.Module): # Keys are linearized in sorted order, matching `tf.nest.flatten`. # # CHECK: func {{@[a-zA-Z_0-9]+}}( - # CHECK-SAME: %arg0: tensor<1xf32> {tf_saved_model.index_path = [0, "x"]}, - # CHECK-SAME: %arg1: tensor<2xf32> {tf_saved_model.index_path = [0, "y"]}) + # CHECK-SAME: %arg0: tensor<1xf32> {tf._user_specified_name = "d", tf_saved_model.index_path = [0, "x"]}, + # CHECK-SAME: %arg1: tensor<2xf32> {tf._user_specified_name = "d", tf_saved_model.index_path = [0, "y"]}) # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0002_dict_2_keys"] @tf.function(input_signature=[{ 'x': tf.TensorSpec([1], tf.float32), @@ -77,8 +77,8 @@ class TestModule(tf.Module): # The index path should be insensitive to the key order. # # CHECK: func {{@[a-zA-Z_0-9]+}}( - # CHECK-SAME: %arg0: tensor<1xf32> {tf_saved_model.index_path = [0, "x"]}, - # CHECK-SAME: %arg1: tensor<2xf32> {tf_saved_model.index_path = [0, "y"]}) + # CHECK-SAME: %arg0: tensor<1xf32> {tf._user_specified_name = "d", tf_saved_model.index_path = [0, "x"]}, + # CHECK-SAME: %arg1: tensor<2xf32> {tf._user_specified_name = "d", tf_saved_model.index_path = [0, "y"]}) # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0003_dict_2_keys_out_of_order"] @tf.function(input_signature=[{ 'y': tf.TensorSpec([2], tf.float32), @@ -90,12 +90,12 @@ class TestModule(tf.Module): # Slightly stronger stress test of multiple dict keys. 
# # CHECK: func {{@[a-zA-Z_0-9]+}}( - # CHECK-SAME: %arg0: tensor<1xf32> {tf_saved_model.index_path = [0, "a"]}, - # CHECK-SAME: %arg1: tensor<2xf32> {tf_saved_model.index_path = [0, "b"]}, - # CHECK-SAME: %arg2: tensor<3xf32> {tf_saved_model.index_path = [0, "c"]}, - # CHECK-SAME: %arg3: tensor<4xf32> {tf_saved_model.index_path = [0, "x"]}, - # CHECK-SAME: %arg4: tensor<5xf32> {tf_saved_model.index_path = [0, "y"]}, - # CHECK-SAME: %arg5: tensor<6xf32> {tf_saved_model.index_path = [0, "z"]}) + # CHECK-SAME: %arg0: tensor<1xf32> {tf._user_specified_name = "d", tf_saved_model.index_path = [0, "a"]}, + # CHECK-SAME: %arg1: tensor<2xf32> {tf._user_specified_name = "d", tf_saved_model.index_path = [0, "b"]}, + # CHECK-SAME: %arg2: tensor<3xf32> {tf._user_specified_name = "d", tf_saved_model.index_path = [0, "c"]}, + # CHECK-SAME: %arg3: tensor<4xf32> {tf._user_specified_name = "d", tf_saved_model.index_path = [0, "x"]}, + # CHECK-SAME: %arg4: tensor<5xf32> {tf._user_specified_name = "d", tf_saved_model.index_path = [0, "y"]}, + # CHECK-SAME: %arg5: tensor<6xf32> {tf._user_specified_name = "d", tf_saved_model.index_path = [0, "z"]}) # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0004_dict_many_keys"] @tf.function(input_signature=[{ 'x': tf.TensorSpec([4], tf.float32), @@ -112,9 +112,9 @@ class TestModule(tf.Module): # Note that list elements can have heterogenous types. # # CHECK: func {{@[a-zA-Z_0-9]+}}( - # CHECK-SAME: %arg0: tensor<1xf32> {tf_saved_model.index_path = [0, "x", 0]}, - # CHECK-SAME: %arg1: tensor<2xf32> {tf_saved_model.index_path = [0, "x", 1]}, - # CHECK-SAME: %arg2: tensor<3xf32> {tf_saved_model.index_path = [0, "y"]}) + # CHECK-SAME: %arg0: tensor<1xf32> {tf._user_specified_name = "d", tf_saved_model.index_path = [0, "x", 0]}, + # CHECK-SAME: %arg1: tensor<2xf32> {tf._user_specified_name = "d", tf_saved_model.index_path = [0, "x", 1]}, + # CHECK-SAME: %arg2: tensor<3xf32> {tf._user_specified_name = "d", tf_saved_model.index_path = [0, "y"]}) # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0005_more_complex_recursive_structure"] @tf.function(input_signature=[{ 'x': [tf.TensorSpec([1], tf.float32), diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir index 38aa078358b..961039e7968 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir @@ -104,3 +104,9 @@ module attributes {tf_saved_model.semantics} { return } } + +// ----- + +// Test running the pass on a module that does not have +// tf_saved_model.semantics. +module {} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir index f985be16ab8..80d9a498253 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir @@ -136,3 +136,9 @@ module attributes {tf_saved_model.semantics} { } } + +// ----- + +// Test running the pass on a module that does not have +// tf_saved_model.semantics. 
+module {} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir index 77ca08c089a..d5fb821b5e6 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir @@ -1,30 +1,388 @@ // RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tpu-extract-head-tail-outside-compilation | FileCheck %s --dump-input-on-failure -// Tests extraction of a single outside compiled cluster with no input or output dependecies. +// Tests extraction of outside compiled ops at the head of a TPU computation. -// CHECK-LABEL: func @nodep_single_head_outside_compilation -func @nodep_single_head_outside_compilation() -> () { - // CHECK: "tf.A" - // CHECK-NEXT: "tf_device.launch" - "tf_device.launch"() ( { - "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> () - "tf.B"() : () -> () - "tf.C"() : () -> () - tf_device.return - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () - return -} +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { + // CHECK-LABEL: func @head_single_outside_compiled_op + func @head_single_outside_compiled_op(%arg0: tensor) { + // CHECK: "tf_device.launch" + // CHECK-NEXT: "tf.A" + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + // + // CHECK: "tf_device.cluster" + // CHECK-NEXT: "tf.B" + // CHECK-NEXT: "tf.C" + // CHECK-NEXT: tf_device.return + "tf_device.cluster"() ( { + "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor) -> () + "tf.B"() : () -> () + "tf.C"() : () -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } -// CHECK-LABEL: func @nodep_multiple_head_outside_compilation -func @nodep_multiple_head_outside_compilation() -> () { - // CHECK: "tf.A" - // CHECK-NEXT: "tf.B" - // CHECK-NEXT: "tf_device.launch" - "tf_device.launch"() ( { - "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> () - "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> () - "tf.C"() : () -> () - tf_device.return - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () - return + // CHECK-LABEL: func @head_single_outside_compiled_op_no_operands + func @head_single_outside_compiled_op_no_operands() { + // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch" + // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A" + // CHECK-NEXT: tf_device.return %[[A_OUT]] + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + // + // CHECK: "tf_device.cluster" + // CHECK-NEXT: "tf.B"(%[[LAUNCH_OUT]]) + // CHECK-NEXT: "tf.C" + // CHECK-NEXT: tf_device.return + "tf_device.cluster"() ( { + %a = "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> tensor + %b = "tf.B"(%a) : (tensor) -> tensor + "tf.C"(%b) : (tensor) -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } + + // CHECK-LABEL: func @head_operand_op_outside_cluster + func @head_operand_op_outside_cluster() { + // CHECK: %[[A_OUT:.*]] = "tf.A" + %a = "tf.A"() : () -> tensor + // CHECK-NEXT: %[[LAUNCH_OUT:.*]] = 
"tf_device.launch" + // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B" + // CHECK-NEXT: tf_device.return %[[B_OUT]] + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + // + // CHECK: "tf_device.cluster" + // CHECK-NEXT: "tf.C"(%[[LAUNCH_OUT]]) + // CHECK-NEXT: "tf.D" + // CHECK-NEXT: tf_device.return + "tf_device.cluster"() ( { + %b = "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor) -> tensor + %c = "tf.C"(%b) : (tensor) -> tensor + "tf.D"(%c) : (tensor) -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } + + // CHECK-LABEL: func @head_aliased_output + func @head_aliased_output() -> (tensor, tensor, tensor) { + // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch" + // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A" + // CHECK-NEXT: tf_device.return %[[A_OUT]] + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + // + // CHECK: %[[CLUSTER_OUT:.*]]:2 = "tf_device.cluster" + // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[LAUNCH_OUT]]) + // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C" + // CHECK-NEXT: tf_device.return %[[C_OUT]], %[[B_OUT]] + // CHECK-NEXT: { + // CHECK-DAG: num_cores_per_replica = 1 + // CHECK-DAG: step_marker_location = "" + // CHECK-DAG: padding_map = [] + // CHECK-DAG: topology = "" + // CHECK-DAG: device_assignment = [] + %cluster:3 = "tf_device.cluster"() ( { + %a = "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> tensor + %b = "tf.B"(%a) : (tensor) -> tensor + %c = "tf.C"(%b) : (tensor) -> tensor + tf_device.return %a, %c, %b : tensor, tensor, tensor + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> (tensor, tensor, tensor) + // CHECK: return %[[LAUNCH_OUT]], %[[CLUSTER_OUT]]#0, %[[CLUSTER_OUT]]#1 + return %cluster#0, %cluster#1, %cluster#2 : tensor, tensor, tensor + } + + // CHECK-LABEL: func @head_all_cluster_op + func @head_all_cluster_op(%arg0: tensor) -> tensor { + // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch" + // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A" + // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]]) + // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%[[B_OUT]], %arg0) + // CHECK-NEXT: tf_device.return %[[C_OUT]] + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + // + // CHECK: "tf_device.cluster" + // CHECK-NEXT: tf_device.return + %cluster = "tf_device.cluster"() ( { + %a = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor) -> tensor + %b = "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor) -> tensor + %c = "tf.C"(%b, %arg0) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> tensor + tf_device.return %c : tensor + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> tensor + // CHECK: return %[[LAUNCH_OUT]] + return %cluster : tensor + } + + // CHECK-LABEL: func @head_multiple_outside_compiled_ops + func @head_multiple_outside_compiled_ops(%arg0: tensor) { + // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch" + // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A" + // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]]) + // CHECK-NEXT: "tf.C" + // CHECK-NEXT: tf_device.return %[[B_OUT]] + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + // + // CHECK: "tf_device.cluster" + // CHECK-NEXT: "tf.D"(%[[LAUNCH_OUT]]) + // CHECK-NEXT: tf_device.return + "tf_device.cluster"() ( { + %a = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor) -> 
tensor + %b = "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor) -> tensor + "tf.C"(%b, %arg0) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> () + "tf.D"(%b) : (tensor) -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } + + // CHECK-LABEL: func @head_replicated_outside_compilation + func @head_replicated_outside_compilation(%arg0: tensor, %arg1: tensor) { + // CHECK: tf_device.replicate([%arg0, %arg1] as %[[RI:.*]]: tensor) + // + // CHECK-NEXT: %[[LAUNCH_OUT:.*]] = "tf_device.launch"() + // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"(%[[RI]]) + // CHECK-NEXT: tf_device.return %[[A_OUT]] + // CHECK-NEXT: device = "TPU_REPLICATED_HOST" + // + // CHECK: "tf_device.cluster" + // CHECK-NEXT: "tf.B"(%[[LAUNCH_OUT]]) + // CHECK-NEXT: tf_device.return + tf_device.replicate([%arg0, %arg1] as %ri : tensor) {n = 2 : i32} { + "tf_device.cluster"() ( { + %a = "tf.A"(%ri) {_xla_outside_compilation = "cluster1"} : (tensor) -> tensor + "tf.B"(%a) : (tensor) -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + tf_device.return + } + return + } + + // CHECK-LABEL: func @tail_single_outside_compiled_op + func @tail_single_outside_compiled_op() { + // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster" + // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A" + // CHECK-NEXT: "tf.C" + // CHECK-NEXT: tf_device.return %[[A_OUT]] + // CHECK-NEXT: { + // CHECK-DAG: num_cores_per_replica = 1 + // CHECK-DAG: step_marker_location = "" + // CHECK-DAG: padding_map = [] + // CHECK-DAG: topology = "" + // CHECK-DAG: device_assignment = [] + // + // CHECK: "tf_device.launch" + // CHECK-NEXT: "tf.B"(%[[CLUSTER_OUT]]) + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + "tf_device.cluster"() ( { + %a = "tf.A"() : () -> tensor + "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor) -> () + "tf.C"() : () -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } + + // CHECK-LABEL: func @tail_single_outside_compiled_op_user + func @tail_single_outside_compiled_op_user() -> tensor { + // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster" + // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A" + // CHECK-NEXT: "tf.C" + // CHECK-NEXT: tf_device.return %[[A_OUT]] + // CHECK-NEXT: { + // CHECK-DAG: num_cores_per_replica = 1 + // CHECK-DAG: step_marker_location = "" + // CHECK-DAG: padding_map = [] + // CHECK-DAG: topology = "" + // CHECK-DAG: device_assignment = [] + // + // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch" + // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[CLUSTER_OUT]]) + // CHECK-NEXT: tf_device.return %[[B_OUT]] + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + %cluster = "tf_device.cluster"() ( { + %a = "tf.A"() : () -> tensor + %b = "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor) -> tensor + "tf.C"() : () -> () + tf_device.return %b : tensor + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> tensor + // CHECK: return %[[LAUNCH_OUT]] + return %cluster : tensor + } + + // CHECK-LABEL: func @tail_multiple_outside_compiled_ops + func @tail_multiple_outside_compiled_ops(%arg0: tensor) { + // CHECK: %[[CLUSTER_OUT:.*]]:2 = "tf_device.cluster" + // CHECK-NEXT: 
%[[A_OUT:.*]] = "tf.A" + // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B" + // CHECK-NEXT: tf_device.return %[[B_OUT]], %[[A_OUT]] + // CHECK-NEXT: { + // CHECK-DAG: num_cores_per_replica = 1 + // CHECK-DAG: step_marker_location = "" + // CHECK-DAG: padding_map = [] + // CHECK-DAG: topology = "" + // CHECK-DAG: device_assignment = [] + // + // CHECK: "tf_device.launch" + // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%arg0, %[[CLUSTER_OUT]]#1) + // CHECK-NEXT: "tf.D"(%[[C_OUT]], %arg0, %[[CLUSTER_OUT]]#0) + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + "tf_device.cluster"() ( { + %a = "tf.A"() : () -> tensor + %b = "tf.B"(%arg0) : (tensor) -> tensor + %c = "tf.C"(%arg0, %a) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> tensor + "tf.D"(%c, %arg0, %b) {_xla_outside_compilation = "cluster1"} : (tensor, tensor, tensor) -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } + + // CHECK-LABEL: func @tail_aliased_output + func @tail_aliased_output() -> (tensor, tensor, tensor, tensor, tensor) { + // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A" + %a = "tf.A"() : () -> tensor + // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B" + %b = "tf.B"() : () -> tensor + // CHECK: %[[CLUSTER_OUT:.*]]:2 = "tf_device.cluster" + // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C" + // CHECK-NEXT: %[[E_OUT:.*]] = "tf.E" + // CHECK-NEXT: tf_device.return %[[C_OUT]], %[[E_OUT]] + // CHECK-NEXT: { + // CHECK-DAG: num_cores_per_replica = 1 + // CHECK-DAG: step_marker_location = "" + // CHECK-DAG: padding_map = [] + // CHECK-DAG: topology = "" + // CHECK-DAG: device_assignment = [] + // + // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch" + // CHECK-NEXT: %[[D_OUT:.*]] = "tf.D"(%[[CLUSTER_OUT]]#0, %[[A_OUT]]) + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + %cluster:5 = "tf_device.cluster"() ( { + %c = "tf.C"() : () -> tensor + %d = "tf.D"(%c, %a) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> tensor + %e = "tf.E"() : () -> tensor + tf_device.return %a, %b, %c, %d, %e : tensor, tensor, tensor, tensor, tensor + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> (tensor, tensor, tensor, tensor, tensor) + // CHECK: return %[[A_OUT]], %[[B_OUT]], %[[CLUSTER_OUT]]#0, %[[LAUNCH_OUT]], %[[CLUSTER_OUT]]#1 + return %cluster#0, %cluster#1, %cluster#2, %cluster#3, %cluster#4 : tensor, tensor, tensor, tensor, tensor + } + + // CHECK-LABEL: func @tail_replicated_outside_compilation + func @tail_replicated_outside_compilation(%arg0: tensor, %arg1: tensor) { + // CHECK: tf_device.replicate([%arg0, %arg1] as %[[RI:.*]]: tensor) + // + // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster" + // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"(%[[RI]]) + // CHECK-NEXT: tf_device.return %[[A_OUT]] + // CHECK-NEXT: { + // CHECK-DAG: num_cores_per_replica = 1 + // CHECK-DAG: step_marker_location = "" + // CHECK-DAG: padding_map = [] + // CHECK-DAG: topology = "" + // CHECK-DAG: device_assignment = [] + // + // CHECK-NEXT: "tf_device.launch"() + // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[CLUSTER_OUT]], %[[RI]]) + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "TPU_REPLICATED_HOST" + tf_device.replicate([%arg0, %arg1] as %ri : tensor) {n = 2 : i32} { + "tf_device.cluster"() ( { + %a = "tf.A"(%ri) : (tensor) -> tensor + %b = "tf.B"(%a, %ri) {_xla_outside_compilation = "cluster1"} : 
(tensor, tensor) -> tensor + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + tf_device.return + } + return + } + + // CHECK-LABEL: func @head_tail_no_extraction_middle_outside_compiled_ops + func @head_tail_no_extraction_middle_outside_compiled_ops(%arg0: tensor) { + // CHECK-NOT: "tf_device.launch" + // CHECK: "tf_device.cluster" + // CHECK-NEXT: "tf.A" + // CHECK-NEXT: "tf.B" + // CHECK-NEXT: "tf.C" + // CHECK-NEXT: tf_device.return + "tf_device.cluster"() ( { + %a = "tf.A"(%arg0) : (tensor) -> tensor + %b = "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor) -> tensor + "tf.C"(%b) : (tensor) -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } + + // CHECK-LABEL: func @head_tail_simple_extraction + func @head_tail_simple_extraction(%arg0: tensor) -> tensor { + // CHECK: %[[HEAD_LAUNCH_OUT:.*]] = "tf_device.launch" + // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"(%arg0) + // CHECK-NEXT: tf_device.return %[[A_OUT]] + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + // + // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster" + // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[HEAD_LAUNCH_OUT]]) + // CHECK-NEXT: tf_device.return %[[B_OUT]] + // CHECK-NEXT: { + // CHECK-DAG: num_cores_per_replica = 1 + // CHECK-DAG: step_marker_location = "" + // CHECK-DAG: padding_map = [] + // CHECK-DAG: topology = "" + // CHECK-DAG: device_assignment = [] + // + // CHECK: %[[TAIL_LAUNCH_OUT:.*]] = "tf_device.launch" + // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%[[CLUSTER_OUT]]) + // CHECK-NEXT: tf_device.return %[[C_OUT]] + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + %cluster = "tf_device.cluster"() ( { + %a = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor) -> tensor + %b = "tf.B"(%a) : (tensor) -> tensor + %c = "tf.C"(%b) {_xla_outside_compilation = "cluster1"} : (tensor) -> tensor + tf_device.return %c : tensor + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> tensor + // CHECK: return %[[TAIL_LAUNCH_OUT]] + return %cluster : tensor + } + + // CHECK-LABEL: func @head_tail_replicated_outside_compilation + func @head_tail_replicated_outside_compilation(%arg0: tensor, %arg1: tensor) { + // CHECK: tf_device.replicate([%arg0, %arg1] as %[[RI:.*]]: tensor) + // + // CHECK-NEXT: %[[HEAD_LAUNCH_OUT:.*]] = "tf_device.launch"() + // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"(%[[RI]]) + // CHECK-NEXT: tf_device.return %[[A_OUT]] + // CHECK-NEXT: device = "TPU_REPLICATED_HOST" + // + // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster" + // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B" + // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%[[RI]], %[[B_OUT]]) + // CHECK-NEXT: "tf.E"(%[[C_OUT]], %[[HEAD_LAUNCH_OUT]]) + // CHECK-NEXT: tf_device.return %[[C_OUT]] + // CHECK-NEXT: { + // CHECK-DAG: num_cores_per_replica = 1 + // CHECK-DAG: step_marker_location = "" + // CHECK-DAG: padding_map = [] + // CHECK-DAG: topology = "" + // CHECK-DAG: device_assignment = [] + // + // CHECK-NEXT: "tf_device.launch"() + // CHECK-NEXT: "tf.D"(%[[HEAD_LAUNCH_OUT]], %[[CLUSTER_OUT]], %[[RI]]) + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "TPU_REPLICATED_HOST" + tf_device.replicate([%arg0, %arg1] as %ri : tensor) {n = 2 : i32} { + "tf_device.cluster"() ( { + %a = "tf.A"(%ri) {_xla_outside_compilation = "cluster1"} : (tensor) 
-> tensor + %b = "tf.B"() : () -> tensor + %c = "tf.C"(%ri, %b) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> tensor + %d = "tf.D"(%a, %c, %ri) {_xla_outside_compilation = "cluster1"} : (tensor, tensor, tensor) -> tensor + %e = "tf.E"(%c, %a) : (tensor, tensor) -> tensor + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + tf_device.return + } + return + } } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir index b2e8f116827..0c4de285b16 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir @@ -3,12 +3,12 @@ // Tests that missing `_xla_outside_compilation` attribute value results in an error. func @missing_outside_compilation_attribute() -> () { - "tf_device.launch"() ( { + "tf_device.cluster"() ( { "tf.A"() : () -> () // expected-error@+1 {{attribute '_xla_outside_compilation' is empty}} "tf.B"() {_xla_outside_compilation = ""} : () -> () tf_device.return - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () + }) {cluster_attr = "cluster_attr"} : () -> () return } @@ -18,11 +18,11 @@ func @missing_outside_compilation_attribute() -> () { // CHECK-LABEL: func @no_outside_compilation func @no_outside_compilation() -> tensor { - %0 = "tf_device.launch"() ( { + %0 = "tf_device.cluster"() ( { %1 = "tf.A"() : () -> tensor %2 = "tf.B"(%1) : (tensor) -> tensor tf_device.return %2 : tensor - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor + }) {cluster_attr = "cluster_attr"} : () -> tensor return %0 : tensor } @@ -36,16 +36,15 @@ func @nodep_single_outside_compilation() -> () { // CHECK-NEXT: "tf_device.launch" // CHECK-NEXT: "tf.B" // CHECK-NOT: _xla_outside_compilation - // CHECK: "tf_device.launch" + // CHECK: "tf_device.cluster" // CHECK-NEXT: "tf.A" - // CHECK: device = "tpu0" - // CHECK-SAME: launch_attr = "launch_attr" - "tf_device.launch"() ( { + // CHECK: cluster_attr = "cluster_attr" + "tf_device.cluster"() ( { "tf.A"() : () -> () "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> () "tf.C"() : () -> () tf_device.return - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () + }) {cluster_attr = "cluster_attr"} : () -> () return } @@ -59,19 +58,18 @@ func @nodep_single_cluster_multiple_ops_outside_compilation() -> () { // CHECK-NEXT: "tf.C" // CHECK-NEXT: "tf.D" // CHECK-NOT: _xla_outside_compilation - // CHECK: "tf_device.launch" + // CHECK: "tf_device.cluster" // CHECK-NEXT: "tf.A" // CHECK-NEXT: "tf.E" - // CHECK: device = "tpu0" - // CHECK-SAME: launch_attr = "launch_attr" - "tf_device.launch"() ( { + // CHECK: cluster_attr = "cluster_attr" + "tf_device.cluster"() ( { "tf.A"() : () -> () "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> () "tf.C"() {_xla_outside_compilation = "cluster1"} : () -> () "tf.D"() {_xla_outside_compilation = "cluster1"} : () -> () "tf.E"() : () -> () tf_device.return - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () + }) {cluster_attr = "cluster_attr"} : () -> () return } @@ -80,15 +78,16 @@ func @nodep_single_cluster_multiple_ops_outside_compilation() -> () { // CHECK-LABEL: func @nodep_multiple_outside_compilation func @nodep_multiple_outside_compilation() -> () { // CHECK: "tf_device.parallel_execute" - // CHECK-COUNT-3: "tf_device.launch" - 
"tf_device.launch"() ( { + // CHECK-COUNT-2: "tf_device.launch" + // CHECK: "tf_device.cluster" + "tf_device.cluster"() ( { "tf.A"() : () -> () "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> () "tf.C"() : () -> () "tf.D"() {_xla_outside_compilation = "cluster2"} : () -> () "tf.E"() : () -> () tf_device.return - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () + }) {cluster_attr = "cluster_attr"} : () -> () return } @@ -97,20 +96,20 @@ func @nodep_multiple_outside_compilation() -> () { // CHECK-LABEL: func @single_tpu_return_single_outside_compilation func @single_tpu_return_single_outside_compilation(%arg0: tensor) -> tensor { %0 = "tf.A"(%arg0) : (tensor) -> tensor - // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate - // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" - // CHECK-NEXT: "tf_device.launch" - // CHECK: %[[TPU_LAUNCH_OUTPUT:[0-9]*]] = "tf_device.launch" - // CHECK: tf_device.return - // CHECK: tf_device.return %[[TPU_LAUNCH_OUTPUT]] - // CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]] + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK: %[[TPU_CLUSTER_OUTPUT:[0-9]*]] = "tf_device.cluster" + // CHECK: tf_device.return + // CHECK: tf_device.return %[[TPU_CLUSTER_OUTPUT]] + // CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]] %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { - %2 = "tf_device.launch"() ( { + %2 = "tf_device.cluster"() ( { "tf.A"() : () -> () "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> () %3 = "tf.C"() : () -> tensor tf_device.return %3 : tensor - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor + }) {cluster_attr = "cluster_attr"} : () -> tensor tf_device.return %2 : tensor } @@ -122,24 +121,321 @@ func @single_tpu_return_single_outside_compilation(%arg0: tensor) -> tens // CHECK-LABEL: func @multiple_tpu_return_single_outside_compilation func @multiple_tpu_return_single_outside_compilation(%arg0: tensor) -> tensor { %0 = "tf.A"(%arg0) : (tensor) -> tensor - // CHECK: %[[REPLICATE:[0-9]*]]:4 = tf_device.replicate - // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:2 = "tf_device.parallel_execute" - // CHECK-NEXT: "tf_device.launch" - // CHECK: %[[TPU_LAUNCH_OUTPUT:[0-9]*]]:2 = "tf_device.launch" - // CHECK: tf_device.return - // CHECK: tf_device.return %[[TPU_LAUNCH_OUTPUT]] - // CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]] + // CHECK: %[[REPLICATE:[0-9]*]]:4 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:2 = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK: %[[TPU_CLUSTER_OUTPUT:[0-9]*]]:2 = "tf_device.cluster" + // CHECK: tf_device.return + // CHECK: tf_device.return %[[TPU_CLUSTER_OUTPUT]] + // CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]] %1:4 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { - %2, %3 = "tf_device.launch"() ( { + %2, %3 = "tf_device.cluster"() ( { %4 = "tf.A"() : () -> tensor "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> () %5 = "tf.C"() : () -> tensor tf_device.return %4, %5 : tensor, tensor - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> (tensor, tensor) + }) {cluster_attr = "cluster_attr"} : () -> (tensor, tensor) tf_device.return %2, %3 : tensor, tensor } return %1 : tensor } -// TODO(b/154363171): Add test cases for when output of outside compilation is returned by parallel_execute. 
+// Tests extraction of a single outside compiled cluster with single device->host input. + +// CHECK-LABEL: func @single_outside_compiled_input_single_outside_compilation +func @single_outside_compiled_input_single_outside_compilation(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir" + // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]]) + // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK: "tf.B"(%[[RECV_OUTPUT]]) + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]]) + // CHECK-SAME: key = "host_compute_channel_cluster1" + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor) -> () + %4 = "tf.C"() : () -> tensor + tf_device.return %4 : tensor + }) {cluster_attr = "cluster_attr"} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor +} + +// Tests extraction of a single outside compiled cluster with single host->device output. + +// CHECK-LABEL: func @single_outside_compiled_output_single_outside_compilation +func @single_outside_compiled_output_single_outside_compilation(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir" + // CHECK: "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]]) + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"() + // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]]) + // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._HostComputeMlir"() + // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK: "tf.C"(%[[HOST_OUTPUT]]) + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + %4 = "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> (tensor) + %5 = "tf.C"(%4) : (tensor) -> tensor + tf_device.return %5 : tensor + }) {cluster_attr = "cluster_attr"} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor +} + +// Tests extraction of a single outside compiled cluster host output returned by TPU cluster. 
+ +// CHECK-LABEL: func @return_host_output_outside_compilation +func @return_host_output_outside_compilation(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir" + // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]]) + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT]]) + // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]]) + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._HostComputeMlir"(%[[A_OUTPUT]]) + // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK: tf_device.return %[[HOST_OUTPUT]] + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor) -> (tensor) + %5 = "tf.C"(%3) : (tensor) -> (tensor) + tf_device.return %4 : tensor + }) {cluster_attr = "cluster_attr"} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor +} + +// Tests extraction of a single outside compiled cluster with single input/output. + +// CHECK-LABEL: func @single_outside_compiled_input_output_single_outside_compilation +func @single_outside_compiled_input_output_single_outside_compilation(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir" + // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]]) + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT]]) + // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]]) + // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._HostComputeMlir"(%[[A_OUTPUT]]) + // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK: "tf.C"(%[[HOST_OUTPUT]]) + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor) -> (tensor) + %5 = "tf.C"(%4) : (tensor) -> tensor + tf_device.return %5 : tensor + }) {cluster_attr = "cluster_attr"} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor +} + + +// Tests extraction of a single outside compiled cluster with multiple input/output. 
+
+// CHECK-LABEL: func @multiple_outside_compiled_input_output_single_outside_compilation
+func @multiple_outside_compiled_input_output_single_outside_compilation(%arg0: tensor) -> tensor {
+ %0 = "tf.A"(%arg0) : (tensor) -> tensor
+ // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
+ // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
+ // CHECK-NEXT: "tf_device.launch"
+ // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
+ // CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
+ // CHECK: %[[B_OUTPUT:[0-9]*]]:2 = "tf.C"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1)
+ // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]]#0, %[[B_OUTPUT]]#1, %[[PROGRAM_OUTPUT]])
+ // CHECK-SAME: key = "host_compute_channel_cluster1"
+ // CHECK: "tf_device.cluster"
+ // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
+ // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
+ // CHECK: %[[HOST_OUTPUT:[0-9]*]]:2 = "tf._HostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]])
+ // CHECK-SAME: key = "host_compute_channel_cluster1"
+ // CHECK: "tf.D"(%[[HOST_OUTPUT]]#0)
+ // CHECK: "tf.E"(%[[HOST_OUTPUT]]#1)
+ %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} {
+ %2 = "tf_device.cluster"() ( {
+ %3 = "tf.A"() : () -> (tensor)
+ %4 = "tf.B"() : () -> (tensor)
+ %5, %6 = "tf.C"(%3, %4) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> (tensor, tensor)
+ %7 = "tf.D"(%5) : (tensor) -> tensor
+ %8 = "tf.E"(%6) : (tensor) -> tensor
+ tf_device.return %8 : tensor
+ }) {cluster_attr = "cluster_attr"} : () -> tensor
+ tf_device.return %2 : tensor
+ }
+
+ return %1 : tensor
+}
+
+// Tests extraction of multiple outside compiled clusters with input/output.
+
+// CHECK-LABEL: func @outside_compiled_input_output_multiple_outside_compilation
+func @outside_compiled_input_output_multiple_outside_compilation(%arg0: tensor) -> tensor {
+ %0 = "tf.A"(%arg0) : (tensor) -> tensor
+ // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
+ // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
+ // CHECK-NEXT: "tf_device.launch"
+ // CHECK: %[[STATUS_OUTPUT2:[a-z_0-9]*]], %[[PROGRAM_OUTPUT2:[a-z_0-9]*]] = "tf._TPUCompileMlir"
+ // CHECK: %[[RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT2]])
+ // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[RECV_OUTPUT2]])
+ // CHECK: "tf._XlaSendFromHost"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT2]])
+ // CHECK-SAME: key = "host_compute_channel_cluster2"
+ // CHECK: "tf_device.launch"
+ // CHECK: %[[STATUS_OUTPUT1:[a-z_0-9]*]], %[[PROGRAM_OUTPUT1:[a-z_0-9]*]] = "tf._TPUCompileMlir"
+ // CHECK: %[[RECV_OUTPUT1:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT1]])
+ // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT1]])
+ // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT1]])
+ // CHECK-SAME: key = "host_compute_channel_cluster1"
+ // CHECK: "tf_device.cluster"
+ // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
+ // CHECK: %[[HOST_OUTPUT1:[0-9]*]] = "tf._HostComputeMlir"(%[[A_OUTPUT]])
+ // CHECK-SAME: key = "host_compute_channel_cluster1"
+ // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[HOST_OUTPUT1]])
+ // CHECK: %[[HOST_OUTPUT2:[0-9]*]] = "tf._HostComputeMlir"(%[[C_OUTPUT]])
+ // CHECK-SAME: key = "host_compute_channel_cluster2"
+ // CHECK: "tf.E"(%[[HOST_OUTPUT2]])
+ %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} {
+ %2 = "tf_device.cluster"() ( {
+ %3 = "tf.A"() : () -> (tensor)
+ %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor) -> (tensor)
+ %5 = 
"tf.C"(%4) : (tensor) -> (tensor) + %6 = "tf.D"(%5) {_xla_outside_compilation = "cluster2"} : (tensor) -> (tensor) + %7 = "tf.E"(%6) : (tensor) -> tensor + tf_device.return %7 : tensor + }) {cluster_attr = "cluster_attr"} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor +} + +// Tests extraction of a single outside compiled cluster with arg input and single device->host input. + +// CHECK-LABEL: func @mixed_input_single_outside_compilation +func @mixed_input_single_outside_compilation(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir" + // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]]) + // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK: "tf.B"(%arg0, %[[RECV_OUTPUT]]) + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]]) + // CHECK-SAME: key = "host_compute_channel_cluster1" + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + "tf.B"(%arg0, %3) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> () + %4 = "tf.C"() : () -> tensor + tf_device.return %4 : tensor + }) {cluster_attr = "cluster_attr"} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor +} + +// Tests extraction of a multiple outside compiled clusters with single device->host input. + +// CHECK-LABEL: func @single_outside_compiled_input_multiple_outside_compilation +func @single_outside_compiled_input_multiple_outside_compilation(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK: %[[STATUS_OUTPUT_2:[a-z_0-9]*]], %[[PROGRAM_OUTPUT_2:[a-z_0-9]*]] = "tf._TPUCompileMlir" + // CHECK: %[[RECV_OUTPUT_2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT_2]]) + // CHECK-SAME: key = "host_compute_channel_cluster2" + // CHECK: "tf.D"(%[[RECV_OUTPUT_2]]) + // CHECK: "tf_device.launch" + // CHECK: %[[STATUS_OUTPUT_1:[a-z_0-9]*]], %[[PROGRAM_OUTPUT_1:[a-z_0-9]*]] = "tf._TPUCompileMlir" + // CHECK: %[[RECV_OUTPUT_1:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT_1]]) + // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK: "tf.B"(%[[RECV_OUTPUT_1]]) + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]]) + // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C" + // CHECK: "tf._HostComputeMlir"(%[[C_OUTPUT]]) + // CHECK-SAME: key = "host_compute_channel_cluster2" + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor) -> () + %4 = "tf.C"() : () -> tensor + "tf.D"(%4) {_xla_outside_compilation = "cluster2"} : (tensor) -> () + tf_device.return %4 : tensor + }) {cluster_attr = "cluster_attr"} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor +} + +// Tests extraction of a single outside compiled cluster with multiple device->host 
inputs. + +// CHECK-LABEL: func @multiple_outside_compiled_inputs_single_outside_compilation +func @multiple_outside_compiled_inputs_single_outside_compilation(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir" + // CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]]) + // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK: "tf.C"(%[[RECV_OUTPUT]]#0) + // CHECK: "tf.D"(%[[RECV_OUTPUT]]#1, %[[RECV_OUTPUT]]#0) + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]]) + // CHECK-SAME: key = "host_compute_channel_cluster1" + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + %4 = "tf.B"() : () -> (tensor) + "tf.C"(%3) {_xla_outside_compilation = "cluster1"} : (tensor) -> () + "tf.D"(%4, %3) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> () + %5 = "tf.E"() : () -> tensor + tf_device.return %5 : tensor + }) {cluster_attr = "cluster_attr"} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir index af0119dab8f..5d65342b4a7 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir @@ -747,7 +747,9 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests simple case of `tf_device.cluster_func` on TPU with replication. +// Tests simple case of `tf_device.cluster_func` on TPU with replication. Under +// data parallelism replicated host devices are also added to the +// tf_device.replicate module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} { // CHECK-LABEL: func @replicated_tpu_cluster_func @@ -758,7 +760,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate // CHECK-SAME: ([%[[A_OUTPUT]], %[[ARG_0]]] as %[[RI_0:[a-z0-9]*]]: tensor) - // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} + // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"], TPU_REPLICATED_HOST = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:0"]} // CHECK-SAME: n = 2 %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[RI_0]]) @@ -1222,6 +1224,51 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- +// Tests simple case of `tf_device.cluster_func` on TPU with replication and +// parallel_execute. 
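+// The single "tf._TPUCompileMlir" result is expected to be forwarded to the
+// other parallel_execute regions (the "tf.D" and "tf.E" regions below), so no
+// additional compile ops are introduced for them.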
+ +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} { + // CHECK-LABEL: func @replicated_parallel_tpu_cluster_func + func @replicated_parallel_tpu_cluster_func(%arg0: tensor) -> tensor { + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + %0 = "tf.A"(%arg0) : (tensor) -> tensor + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch" + // CHECK: "tf._TPUCompileMlir" + // CHECK: "tf.TPUCompileSucceededAssert" + // CHECK: "tf_device.parallel_execute" + // CHECK-NOT:"tf._TPUCompileMlir" + // CHECK: "tf.D"(%[[COMPILE_OUTPUT]]#1 + // CHECK: "tf.TPUExecute" + // CHECK-NOT:"tf._TPUCompileMlir" + // CHECK: "tf.E"(%[[COMPILE_OUTPUT]]#1 + %3 = "tf_device.parallel_execute"() ( { + %status, %program = "tf._TPUCompileMlir"() {metadata = "...", mlir_module = "..."} : () -> (tensor, tensor) + "tf.D"(%program) : (tensor) -> () + tf_device.return + }, { + %4 = "tf_device.cluster_func"(%ri_0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + tf_device.return %4 : tensor + }, { + %status, %program = "tf._TPUCompileMlir"() {metadata = "...", mlir_module = "..."} : () -> (tensor, tensor) + "tf.E"(%program) : (tensor) -> () + tf_device.return + }) : () -> (tensor) + tf_device.return %3 : tensor + } + %2 = "tf.C"(%1#1) : (tensor) -> tensor + return %2 : tensor + } + + func @tpu0_func(%arg0: tensor) -> tensor { + %0 = "tf.B"(%arg0) : (tensor) -> tensor + return %0 : tensor + } +} + +// ----- + // Tests devices are set properly for non replicated model parallelism. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"]} { @@ -1282,15 +1329,14 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01" // ----- -// Tests devices are set properly for replicated model parallelism. +// Tests devices are set properly for replicated model parallelism. No +// replicated host device should be present. 
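+// (The devices map checked below contains only TPU_REPLICATED_CORE_* aliases;
+// no TPU_REPLICATED_HOST alias is added in this case.)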
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} { // CHECK-LABEL: func @replicated_parallel_execute func @replicated_parallel_execute(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> (tensor<8xi32>, tensor<8xi32>) { // CHECK: tf_device.replicate - // CHECK-SAME: devices = - // CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"] - // CHECK-SAME: TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"] + // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"], TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"]} %0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<8xi32>) {n = 2 : i32} { // CHECK-NEXT: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"() @@ -1322,8 +1368,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // ----- -// Tests that inputs are inputs with maximal and replicate sharding are set properly -// for replicated model parallelism. +// Tests that inputs are inputs with maximal and replicate sharding are set +// properly for replicated model parallelism. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} { // CHECK-LABEL: func @parallel_execute_with_input_with_sharding_configurations @@ -1357,8 +1403,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // ----- -// Tests devices are set properly for replicated model parallelism with -// outputs to TPU computation placed on logical device 0. +// Tests devices are set properly for replicated model parallelism with outputs +// to TPU computation placed on logical device 0. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} { // CHECK-LABEL: func @parallel_execute_with_different_outputs @@ -1434,8 +1480,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // ----- -// Tests inputs are correctly split and fed into TPU computation for -// tiled input sharding. +// Tests inputs are correctly split and fed into TPU computation for tiled input +// sharding. 
// The following OpSharding is used for TPU computation inputs in below test: // Proto debug string: diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_space_to_depth_pass.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_space_to_depth_pass.mlir new file mode 100644 index 00000000000..aa333caa2ae --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_space_to_depth_pass.mlir @@ -0,0 +1,87 @@ +// RUN: tf-opt %s -split-input-file -tf-tpu-space-to-depth-pass | FileCheck %s --dump-input=fail + +// Tests for space to depth host and device transform. + +module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:CPU:0" = {}, "/job:localhost/replica:0/task:0/device:TPU:0" = {}, "/job:localhost/replica:0/task:0/device:TPU:1" = {}, "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0" = {}}, tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 390 : i32}} { + func @main(%arg0: tensor {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg1: tensor {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg2: tensor {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg3: tensor {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg4: tensor>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg5: tensor>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg6: tensor>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg7: tensor>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) attributes {tf.entry_function = {control_outputs = "while", inputs = "iterator,iterator_1,iterator_2,iterator_3,while_input_6,while_input_7,while_input_8,while_input_9", outputs = ""}} { + %0 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<2> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %3:10 = "tf.While"(%2, %1, %2, %0, %1, %arg2, %arg4, %arg5, %arg6, %arg7) {_lower_using_switch_merge = true, _num_original_outputs = 10 : i64, _read_only_resource_inputs = [], body = @while_body_2710, cond = @while_cond_2700, device = "", is_stateless = false, output_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>], parallel_iterations = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor>>, tensor>>, tensor>>, tensor>>) -> (tensor, tensor, tensor, tensor, tensor, tensor, tensor>>, tensor>>, tensor>>, tensor>>) + return + } + // CHECK-LABEL: func @while_body_2710 + func @while_body_2710(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg6: tensor>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg7: tensor>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg8: tensor>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg9: tensor>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> (tensor, tensor, tensor, tensor, tensor, tensor, tensor>>, tensor>>, tensor>>, tensor>>) attributes {tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: %[[INPUT:.*]] = "tf.IteratorGetNext" + %1 = "tf.IteratorGetNext"(%arg5) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor) -> tensor<2x224x224x3xf32> + // CHECK-DAG: %[[SPACETODEPTH0:.*]] = "tf.SpaceToDepth"([[INPUT:.*]]) {block_size = 2 : i64, 
data_format = "NHWC"} : (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32> + %2 = "tf.AddV2"(%arg2, %arg3) {device = ""} : (tensor, tensor) -> tensor + %3 = "tf.ReadVariableOp"(%arg6) : (tensor>>) -> tensor<7x7x3x64xf32> + %4 = "tf.ReadVariableOp"(%arg8) : (tensor>>) -> tensor + %5 = "tf.ReadVariableOp"(%arg7) : (tensor>>) -> tensor + %6 = "tf.ReadVariableOp"(%arg9) : (tensor>>) -> tensor + %7:2 = "tf_device.cluster_func"(%1, %3, %5, %6) {_tpu_replicate = "while/cluster_while_body_271", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [0, 0, 0, 0], func = @_func, host_compute_core = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], num_cores_per_replica = 1 : i64, output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], padding_map = [], step_marker_location = "STEP_MARK_AT_ENTRY", topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01", use_tpu = true} : (tensor<2x224x224x3xf32>, tensor<7x7x3x64xf32>, tensor, tensor) -> (tensor<7x7x3x64xf32>, tensor) + "tf.AssignVariableOp"(%arg6, %7#0) : (tensor>>, tensor<7x7x3x64xf32>) -> () + "tf.AssignVariableOp"(%arg9, %7#1) : (tensor>>, tensor) -> () + %8 = "tf.Identity"(%arg1) {device = ""} : (tensor) -> tensor + %9 = "tf.Identity"(%2) {device = ""} : (tensor) -> tensor + %10 = "tf.AddV2"(%arg0, %0) {device = ""} : (tensor, tensor) -> tensor + %11 = "tf.Identity"(%10) {device = ""} : (tensor) -> tensor + return %11, %8, %9, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9 : tensor, tensor, tensor, tensor, tensor, tensor, tensor>>, tensor>>, tensor>>, tensor>> + } + func @while_cond_2700(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg6: tensor>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg7: tensor>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg8: tensor>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg9: tensor>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> tensor { + %0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %1 = "tf.GreaterEqual"(%arg3, %0) {device = ""} : (tensor, tensor) -> tensor + %2 = "tf.Less"(%arg3, %0) {device = ""} : (tensor, tensor) -> tensor + %3 = "tf.Greater"(%arg2, %arg4) {device = ""} : (tensor, tensor) -> tensor + %4 = "tf.LogicalAnd"(%2, %3) {device = ""} : (tensor, tensor) -> tensor + %5 = "tf.Less"(%arg2, %arg4) {device = ""} : (tensor, tensor) -> tensor + %6 = "tf.LogicalAnd"(%1, %5) {device = ""} : (tensor, tensor) -> tensor + %7 = "tf.LogicalOr"(%6, %4) {device = ""} : (tensor, tensor) -> tensor + %8 = "tf.Less"(%arg0, %arg1) {device = ""} : (tensor, tensor) -> tensor + %9 = "tf.LogicalAnd"(%8, %7) {device = ""} : (tensor, tensor) -> tensor + %10 = "tf.Identity"(%9) {device = ""} : (tensor) -> tensor + return %10 : tensor + } + // CHECK-LABEL: func @_func + // CHECK-SAME: [[FUNCINPUT0:.*]]: tensor<2x112x112x12xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[FUNCINPUT1:%.*]]: tensor<7x7x3x64xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[FUNCINPUT2:%.*]]: tensor {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[VAL_59:%.*]]: tensor {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<7x7x3x64xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) 
attributes {sym_visibility = "private"} { + func @_func(%arg0: tensor<2x224x224x3xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<7x7x3x64xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<7x7x3x64xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} { + %0 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<0> : tensor<1x1xi32>} : () -> tensor<1x1xi32> + %2 = "tf.Const"() {value = dense<[7, 7, 3, 64]> : tensor<4xi32>} : () -> tensor<4xi32> + %3 = "tf.Const"() {value = dense<[[0, 0], [3, 3], [3, 3], [0, 0]]> : tensor<4x2xi32>} : () -> tensor<4x2xi32> + %4 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %5 = "tf.Pad"(%arg0, %3) : (tensor<2x224x224x3xf32>, tensor<4x2xi32>) -> tensor<2x230x230x3xf32> + // CHECK: "tf.Conv2D" + // CHECK-SAME: strides = [1, 1, 1, 1] + // CHECK-SAME: (tensor<2x115x115x12xf32>, tensor<4x4x12x64xf32>) -> tensor<2x112x112x64xf32> + %6 = "tf.Conv2D"(%5, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true} : (tensor<2x230x230x3xf32>, tensor<7x7x3x64xf32>) -> tensor<2x112x112x64xf32> + // CHECK: %[[BACKPROP:.*]] = "tf.Conv2DBackpropFilter" + // CHECK-SAME: strides = [1, 1, 1, 1] + // CHECK-SAME: (tensor<2x115x115x12xf32>, tensor<4xi32>, tensor<2x112x112x64xf32>) -> tensor<4x4x12x64xf32> + %7 = "tf.Conv2DBackpropFilter"(%5, %2, %6) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true} : (tensor<2x230x230x3xf32>, tensor<4xi32>, tensor<2x112x112x64xf32>) -> tensor<7x7x3x64xf32> + // CHECK: %[[CONST0:.*]] = "tf.Const"() {value = dense< + // CHECK-SAME: [4, 4, 2, 2, 3, 64] + // CHECK: %[[RESHAPE0:.*]] = "tf.Reshape"(%[[BACKPROP:.*]], %[[CONST0:.*]]) : (tensor<4x4x12x64xf32>, tensor<6xi64>) -> tensor<4x4x2x2x3x64xf32> + // CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense< + // CHECK-SAME: [0, 2, 1, 3, 4, 5] + // CHECK: %[[TRANSPOSE0:.*]] = "tf.Transpose"(%[[RESHAPE0:.*]], %[[CONST1:.*]]) : (tensor<4x4x2x2x3x64xf32>, tensor<6xi32>) -> tensor<4x2x4x2x3x64xf32> + // CHECK: %[[CONST2:.*]] = "tf.Const"() {value = dense< + // CHECK-SAME: [8, 8, 3, 64] + // CHECK: %[[RESHAPE1:.*]] = "tf.Reshape"(%[[TRANSPOSE1:.*]], %[[CONST2:.*]]) : (tensor<4x2x4x2x3x64xf32>, tensor<4xi64>) -> tensor<8x8x3x64xf32> + // CHECK: %[[CONST3:.*]] = "tf.Const"() {value = dense< + // CHECK-SAME: [7, 7, 3, 64] + // CHECK: %[[CONST4:.*]] = "tf.Const"() {value = dense< + // CHECK-SAME: 0 + // CHECK: %[[SLICE0:.*]] = "tf.Slice"(%[[RESHAPE1:.*]], %[[CONST4:.*]], %[[CONST3:.*]]) : (tensor<8x8x3x64xf32>, tensor<4xi64>, tensor<4xi32>) -> tensor<7x7x3x64xf32> + %8 = "tf.CrossReplicaSum"(%7, %1) : (tensor<7x7x3x64xf32>, tensor<1x1xi32>) -> tensor<7x7x3x64xf32> + %9 = "tf.Mul"(%arg2, %8) : (tensor, tensor<7x7x3x64xf32>) -> tensor<7x7x3x64xf32> + %10 = "tf.Sub"(%arg1, %9) : (tensor<7x7x3x64xf32>, tensor<7x7x3x64xf32>) -> tensor<7x7x3x64xf32> + %11 = "tf.AddV2"(%arg3, %0) : (tensor, tensor) -> tensor + return %10, %11 : tensor<7x7x3x64xf32>, tensor + } +} + +// ---- + diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td index 
ccc3e83a2a2..cf09f8d64fb 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td @@ -152,6 +152,23 @@ def RealDivWithSqrtDivisor : Pat<(TF_RealDivOp $arg0, (TF_SqrtOp $arg1)), def ReciprocalNested : Pat<(TF_ReciprocalOp (TF_ReciprocalOp $arg)), (replaceWithValue $arg)>; +//===----------------------------------------------------------------------===// +// Select op patterns. +//===----------------------------------------------------------------------===// + +def ReshapeSelectPredIfNecessary : NativeCodeCall< + "ReshapeSelectPredIfNecessary(&($_builder), $0.getOwner()->getLoc(), $1, " + "$2.getType().cast().getRank())">; + +// Select supports tensor `condition` where the shape is equal to the first +// dimension of t and e. SelectV2 op supports normal broadcasting, so in these +// cases the condition needs to be reshaped. +def SelectToSelectV2 : Pat< + (TF_SelectOp:$op StaticShapeTensorOf<[AnyType]>:$cond, + StaticShapeTensorOf<[AnyType]>:$t, + StaticShapeTensorOf<[AnyType]>:$e), + (TF_SelectV2Op (ReshapeSelectPredIfNecessary $op, $cond, $t), $t, $e)>; + //===----------------------------------------------------------------------===// // Square op patterns. //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc index be35c6caa16..55a0b5c3fd3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "mlir/Interfaces/SideEffects.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/tf_status.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc index 08645333d5d..e04f6bf3daa 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc @@ -49,6 +49,16 @@ struct TPUBridgeExecutorIslandOutlining void runOnOperation() override; }; +// Move FuncOp referenced by `symbol_ref` from one symbol table to another. 
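+// If the destination table already contains a function with that name the
+// reference is left untouched; otherwise the callee is detached from its
+// current block and inserted into `to`.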
+void MoveFuncOp(FlatSymbolRefAttr &symbol_ref, SymbolTable &from, + SymbolTable &to) { + if (to.lookup(symbol_ref.getValue())) return; + FuncOp callee = from.lookup(symbol_ref.getValue()); + callee.getOperation()->getBlock()->getOperations().remove( + callee.getOperation()); + to.insert(callee); +} + void TPUBridgeExecutorIslandOutlining::runOnOperation() { MLIRContext *ctx = &getContext(); @@ -141,14 +151,17 @@ void TPUBridgeExecutorIslandOutlining::runOnOperation() { for (FuncOp func : outlined_module.getOps()) { func.walk([&](Operation *op) { for (NamedAttribute attr : op->getAttrs()) { - auto symbol_ref = attr.second.dyn_cast(); - if (!symbol_ref) continue; - if (outlined_symbol_table.lookup(symbol_ref.getValue())) + if (auto symbol_ref = attr.second.dyn_cast()) { + MoveFuncOp(symbol_ref, symbol_table, outlined_symbol_table); continue; - FuncOp callee = symbol_table.lookup(symbol_ref.getValue()); - callee.getOperation()->getBlock()->getOperations().remove( - callee.getOperation()); - outlined_symbol_table.insert(callee); + } + if (auto array_attr = attr.second.dyn_cast()) { + for (const Attribute &attribute : array_attr) { + auto symbol_ref = attribute.dyn_cast(); + if (!symbol_ref) continue; + MoveFuncOp(symbol_ref, symbol_table, outlined_symbol_table); + } + } } }); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc index d3b064f3efa..a0cf9c8eb9a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc @@ -48,6 +48,9 @@ struct FreezeGlobalTensorsPass void FreezeGlobalTensorsPass::runOnOperation() { auto module = getOperation(); + if (!tf_saved_model::HasTfSavedModelSemantics(module)) { + return; + } SymbolTable symbol_table(module); DenseSet frozen_global_tensors; @@ -66,7 +69,9 @@ void FreezeGlobalTensorsPass::runOnOperation() { // previous optimize global tensors pass). If not, this pass has to fail // since it cannot perform one of its goals. if (global_tensor.is_mutable()) { - global_tensor.emitError() << "is not immutable"; + global_tensor.emitError() << "is not immutable, try running " + "tf-saved-model-optimize-global-tensors " + "to prove tensors are immutable"; return signalPassFailure(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc index 50f77cd9c3d..d635d605607 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc @@ -16,21 +16,64 @@ limitations under the License. // This file implements logic for legalizing HLO to TensorFlow. 
#include +#include +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/xla/ir/chlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" namespace mlir { namespace TF { namespace { +class ConvertSliceOp : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + xla_hlo::SliceOp slice_op, ArrayRef args, + ConversionPatternRewriter &rewriter) const final { + DenseIntElementsAttr strides = slice_op.strides(); + // Strides must be 1 otherwise we cannot legalize this `xla_hlo.slice` op. + if (!strides.isSplat() || + strides.getSplatValue().cast().getInt() != 1) + return failure(); + + rewriter.setInsertionPointAfter(slice_op); + auto start_indices = slice_op.start_indices(); + auto limit_indices = slice_op.limit_indices(); + std::vector size_values; + for (auto pair : llvm::zip(start_indices.getValues(), + limit_indices.getValues())) { + size_values.emplace_back(std::get<1>(pair).getSExtValue() - + std::get<0>(pair).getSExtValue()); + } + + RankedTensorType ty = + RankedTensorType::get({static_cast(size_values.size())}, + rewriter.getIntegerType(64)); + auto start = rewriter.create(slice_op.getLoc(), start_indices); + auto size = rewriter.create( + slice_op.getLoc(), DenseIntElementsAttr::get(ty, size_values)); + rewriter.replaceOpWithNewOp(slice_op, slice_op.getType(), + slice_op.operand(), start, size); + return success(); + }; +}; + class LegalizeHloToTf : public PassWrapper { public: LegalizeHloToTf() = default; @@ -54,6 +97,68 @@ static bool AreBroadcastCompatible(Value x, Value y) { y_ranked.getShape(), resultShape); } +// Returns the shape of the given value in a Constant Op. +ConstantOp ShapeToConst(PatternRewriter &rewriter, Value value) { + ArrayRef shape = value.getType().cast().getShape(); + auto attr_type = RankedTensorType::get({static_cast(shape.size())}, + rewriter.getIntegerType(64)); + auto attr = DenseElementsAttr::get(attr_type, shape); + return rewriter.create(value.getLoc(), attr_type, attr); +} + +// Converts xla_hlo.dot to tf.MatMul. Reshape ops will be inserted when +// necessary. +Value ConvertDotOp(PatternRewriter &rewriter, Operation *old_op) { + auto dot_op = cast(old_op); + const mlir::Location loc = dot_op.getLoc(); + // Normalizes a ShapedType to 2d if the ShapedType is less than 2d by + // inserting dummy 1-element dimensions in the begining. Does nothing if the + // old shape is already 2d or higher. This is necessary because tf.MatMul + // requires input tensors to be at least 2d. 
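+ // For example, a rank-1 shape such as [k] is padded to [1, k] and a scalar
+ // becomes [1, 1]; shapes that are already 2d or higher are returned as is.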
+ const auto normalize_rank = [](ShapedType type) -> ShapedType { + if (type.getRank() >= 2) { + return type; + } + + const int rank = type.getRank(); + llvm::SmallVector shape_2d(type.getShape().begin(), + type.getShape().end()); + for (int i = 0; i < 2 - rank; ++i) { + shape_2d.insert(shape_2d.begin(), 1); + } + return RankedTensorType::get(shape_2d, type.getElementType()); + }; + + // Reshapes a tensor value to 2d if it is 1d or scalar. Otherwise does + // nothing. + const auto reshape_to_2d = [&rewriter, &loc, + &normalize_rank](mlir::Value input) { + const auto input_type = input.getType().cast(); + if (input_type.getRank() >= 2) { + return input; + } + + auto reshape = rewriter.create( + loc, normalize_rank(input_type), input); + return reshape.getResult(); + }; + + // Reshapes both operand to be 2d for tf.MatMul op. + auto a = reshape_to_2d(dot_op.lhs()); + auto b = reshape_to_2d(dot_op.rhs()); + // Operand `b` needs to be transposed if it is 1d. This is because dot op will + // contract on the only dimension if rhs is 1d. + auto b_old_type = dot_op.rhs().getType().cast(); + BoolAttr transpose_b = rewriter.getBoolAttr(b_old_type.getRank() == 1); + auto output_type = dot_op.getResult().getType().cast(); + auto matmul = rewriter.create( + loc, normalize_rank(output_type), a, b, + /*transpose_a=*/rewriter.getBoolAttr(false), transpose_b); + auto reshape = + rewriter.create(loc, output_type, matmul.product()); + return reshape.getResult(); +} + #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_legalize_hlo.inc" /// Performs the lowering to XLA dialect. @@ -63,12 +168,15 @@ void LegalizeHloToTf::runOnFunction() { // Add legalization patterns to the list. OwningRewritePatternList patterns; populateWithGenerated(&context, &patterns); + patterns.insert(&context); ConversionTarget target(context); target.addLegalDialect(); - target.addLegalOp(); - if (failed(applyPartialConversion(getFunction(), target, patterns))) + target.addLegalOp(); + if (failed(applyPartialConversion(getFunction(), target, patterns))) { + getFunction().emitError("xla_hlo to TF legalization failed."); signalPassFailure(); + } } static PassRegistration pass( diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td index f3371989b73..e4b6a28d65f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td @@ -18,49 +18,65 @@ limitations under the License. include "mlir/IR/OpBase.td" include "mlir/Dialect/StandardOps/IR/Ops.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" +include "tensorflow/compiler/mlir/xla/ir/chlo_ops.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td" -def : Pat<(HLO_ConstOp $value), (TF_ConstOp $value)>; - -//===----------------------------------------------------------------------===// -// Binary op patterns. 
-//===----------------------------------------------------------------------===// - // Check that two values can be broadcasted together // TODO(jpienaar): Move somewhere more general def AreBroadcastCompatible : Constraint, "types must be broadcastable">; -foreach fromToBinPair = [[HLO_AddOp, TF_AddV2Op], - [HLO_DivOp, TF_DivOp], - [HLO_ShiftLeftOp, TF_LeftShiftOp], - [HLO_MaxOp, TF_MaximumOp], - [HLO_MinOp, TF_MinimumOp], - [HLO_MulOp, TF_MulOp], - [HLO_PowOp, TF_PowOp], - [HLO_SubOp, TF_SubOp], - [HLO_Atan2Op, TF_Atan2Op], - [HLO_RemOp, TF_ModOp]] in - def : Pat<(fromToBinPair[0] $l, $r, $_), (fromToBinPair[1] $l, $r), - [(AreBroadcastCompatible $l, $r)]>; +// Return a constant op that carries the shape of the given value. +def ShapeToConst : NativeCodeCall<"ShapeToConst($_builder, $0)">; -foreach pair = [[HLO_AndOp, TF_BitwiseAndOp], - [HLO_OrOp, TF_BitwiseOrOp], - [HLO_XorOp, TF_BitwiseXorOp]] in - def : Pat<(pair[0] TF_IntTensor:$l, TF_IntTensor:$r, $_), (pair[1] $l, $r), - [(AreBroadcastCompatible $l, $r)]>; +def : Pat<(HLO_ConstOp $value), (TF_ConstOp $value)>; -foreach pair = [[HLO_AndOp, TF_LogicalAndOp], - [HLO_OrOp, TF_LogicalOrOp]] in - def : Pat<(pair[0] I1Tensor:$l, I1Tensor:$r, $_), (pair[1] $l, $r), - [(AreBroadcastCompatible $l, $r)]>; +//===----------------------------------------------------------------------===// +// Binary op patterns. +// Note that these are legalized from chlo.broadcast_* ops, since those are +// semantically compatible with the corresponding TF ops. Depending on +// context, getting to these ops may require some raising. +//===----------------------------------------------------------------------===// -def : Pat<(HLO_ShiftRightArithmeticOp $l, $r, $_), (TF_RightShiftOp $l, $r), +foreach fromToBinPair = [[HLO_AddOp, HLOClient_BroadcastAddOp, TF_AddV2Op], + [HLO_DivOp, HLOClient_BroadcastDivOp, TF_DivOp], + [HLO_ShiftLeftOp, HLOClient_BroadcastShiftLeftOp, TF_LeftShiftOp], + [HLO_MaxOp, HLOClient_BroadcastMaxOp, TF_MaximumOp], + [HLO_MinOp, HLOClient_BroadcastMinOp, TF_MinimumOp], + [HLO_MulOp, HLOClient_BroadcastMulOp, TF_MulOp], + [HLO_PowOp, HLOClient_BroadcastPowOp, TF_PowOp], + [HLO_SubOp, HLOClient_BroadcastSubOp, TF_SubOp], + [HLO_Atan2Op, HLOClient_BroadcastAtan2Op, TF_Atan2Op], + [HLO_RemOp, HLOClient_BroadcastRemOp, TF_ModOp]] in { + def : Pat<(fromToBinPair[0] $l, $r), (fromToBinPair[2] $l, $r)>; + def : Pat<(fromToBinPair[1] $l, $r, $_), (fromToBinPair[2] $l, $r), + [(AreBroadcastCompatible $l, $r)]>; +} + +foreach pair = [[HLO_AndOp, HLOClient_BroadcastAndOp, TF_BitwiseAndOp], + [HLO_OrOp, HLOClient_BroadcastOrOp, TF_BitwiseOrOp], + [HLO_XorOp, HLOClient_BroadcastXorOp, TF_BitwiseXorOp]] in { + def : Pat<(pair[0] TF_IntTensor:$l, TF_IntTensor:$r), (pair[2] $l, $r)>; + def : Pat<(pair[1] TF_IntTensor:$l, TF_IntTensor:$r, $_), (pair[2] $l, $r), + [(AreBroadcastCompatible $l, $r)]>; +} + +foreach pair = [[HLO_AndOp, HLOClient_BroadcastAndOp, TF_LogicalAndOp], + [HLO_OrOp, HLOClient_BroadcastOrOp, TF_LogicalOrOp]] in { + def : Pat<(pair[0] I1Tensor:$l, I1Tensor:$r), (pair[2] $l, $r)>; + def : Pat<(pair[1] I1Tensor:$l, I1Tensor:$r, $_), (pair[2] $l, $r), + [(AreBroadcastCompatible $l, $r)]>; +} + +def : Pat<(HLO_ShiftRightArithmeticOp $l, $r), (TF_RightShiftOp $l, $r)>; +def : Pat<(HLOClient_BroadcastShiftRightArithmeticOp $l, $r, $_), (TF_RightShiftOp $l, $r), [(AreBroadcastCompatible $l, $r)]>; -def : Pat<(HLO_ShiftRightLogicalOp $l, $r, $_), (TF_RightShiftOp $l, $r), +def : Pat<(HLO_ShiftRightLogicalOp $l, $r), (TF_RightShiftOp $l, $r)>; 
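+// As with the groups above, the chlo.broadcast_* form of each op is only
+// rewritten when its operands are broadcast compatible.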
+def : Pat<(HLOClient_BroadcastShiftRightLogicalOp $l, $r, $_), (TF_RightShiftOp $l, $r), [(AreBroadcastCompatible $l, $r)]>; -def : Pat<(HLO_FloorOp (HLO_DivOp $l, $r, $_)), (TF_FloorDivOp $l, $r), +def : Pat<(HLO_FloorOp (HLO_DivOp $l, $r)), (TF_FloorDivOp $l, $r)>; +def : Pat<(HLO_FloorOp (HLOClient_BroadcastDivOp $l, $r, $_)), (TF_FloorDivOp $l, $r), [(AreBroadcastCompatible $l, $r)]>; def : Pat<(HLO_ComplexOp $r, $i), (TF_ComplexOp $r, $i)>; @@ -99,6 +115,8 @@ def : Pat<(HLO_BroadcastOp $arg, $shape), def : Pat<(HLO_TransposeOp $arg, $permutation), (TF_TransposeOp $arg, (TF_ConstOp $permutation))>; def : Pat<(HLO_ReverseOp $op, $dims), (TF_ReverseV2Op $op, (TF_ConstOp $dims))>; +def : Pat<(HLO_ReshapeOp:$output $input), + (TF_ReshapeOp $input, (ShapeToConst $output))>; //===----------------------------------------------------------------------===// // Ternary op patterns. @@ -117,16 +135,29 @@ def : Pat<(HLO_ConcatenateOp $inputs, $dim), //===----------------------------------------------------------------------===// // Compare op patterns. +// Note that these are legalized from chlo.broadcast_* ops, since those are +// semantically compatible with the corresponding TF ops. Depending on +// context, getting to these ops may require some raising. //===----------------------------------------------------------------------===// foreach p = [[TF_EqualOp, HLO_COMPARISON_DIRECTION_EQ], - [TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE]] in - def : Pat<(HLO_CompareOp $l, $r, $_, p[1]), (p[0] $l, $r, ConstBoolAttrTrue), + [TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE]] in { + def : Pat<(HLOClient_BroadcastCompareOp $l, $r, $_, p[1]), (p[0] $l, $r, ConstBoolAttrTrue), [(AreBroadcastCompatible $l, $r)]>; + def : Pat<(HLO_CompareOp $l, $r, p[1]), (p[0] $l, $r, ConstBoolAttrTrue)>; +} foreach pair = [[TF_GreaterEqualOp, HLO_COMPARISON_DIRECTION_GE], [TF_GreaterOp, HLO_COMPARISON_DIRECTION_GT], [TF_LessEqualOp, HLO_COMPARISON_DIRECTION_LE], - [TF_LessOp, HLO_COMPARISON_DIRECTION_LT]] in - def : Pat<(HLO_CompareOp $l, $r, $_, pair[1]), (pair[0] $l, $r), + [TF_LessOp, HLO_COMPARISON_DIRECTION_LT]] in { + def : Pat<(HLOClient_BroadcastCompareOp $l, $r, $_, pair[1]), (pair[0] $l, $r), [(AreBroadcastCompatible $l, $r)]>; + def : Pat<(HLO_CompareOp $l, $r, pair[1]), (pair[0] $l, $r)>; +} + +def ConvertDotOp : NativeCodeCall<"ConvertDotOp($_builder, " + "$0.getDefiningOp())">; +def : Pat<(HLO_DotOp:$old_value AnyStaticShapeTensor:$lhs, + AnyStaticShapeTensor:$rhs, $precision_config), + (ConvertDotOp $old_value)>; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc index 550100c8ebf..cd8f988fd5f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc @@ -278,6 +278,10 @@ void EraseUnusedBoundInputs(ModuleOp module) { void OptimizeGlobalTensorsPass::runOnOperation() { auto module = getOperation(); + if (!tf_saved_model::HasTfSavedModelSemantics(module)) { + return; + } + EraseUnusedBoundInputs(module); ResourceAnalyzer resource_analyzer(module); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index c1d99c2dee3..5c140ddd6aa 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -91,6 +91,15 @@ std::unique_ptr> 
CreateResourceDeviceInferencePass(); // of their aliasing output arguments. std::unique_ptr> CreatePromoteResourcesToArgsPass(); +// Creates a pass that promotes tf.VarHandleOp to resource arguments for all +// functions. +std::unique_ptr> CreatePromoteVarHandlesToArgsPass(); + +// Creates a pass that converts readonly reference variables to the +// corresponding resource variables. +std::unique_ptr> +CreateConvertReadonlyReferenceVariablesToResourceVariablesPass(); + // Marks function visibility using tf.entry_function specification. That is, // functions with tf.entry_function attributes are marked with public // visibility while the other functions are marked with private visibility. @@ -258,7 +267,7 @@ std::unique_ptr> CreateTPUVariableReformattingPass(); // Creates a pass that extracts outside compilation (CPU ops inside TPU cluster) // at head/tail of TPU cluster to run before/after TPU computation. -std::unique_ptr> +std::unique_ptr> CreateTPUExtractHeadTailOutsideCompilationPass(); // Creates a pass that extract outside compilation (CPU ops inside TPU cluster) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc index fa4fe461317..cece23b4750 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc @@ -47,11 +47,14 @@ limitations under the License. // . Dead functions have already been removed, as resource arguments in dead // functions can cause the pass to fail. +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/PointerUnion.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project @@ -73,114 +76,189 @@ constexpr char kResourceFunctionMsg[] = "expects function level resource argument"; constexpr char kInvalidResourceMsg[] = "expects resource to be a VarHandleOp or function argument"; +constexpr char kResourceNameArgAttr[] = "tf.resource_name"; -// Records the input argument index and the current live value for a resource -// variable. -// -// . If the input argument already exists or has been added, input_index is the -// index of the function, and live_value_or_type tracks the live value of the -// resource. -// -// . If the input argument has not been added in the pass, input_index is -// kInputUnassigned, live_value_or_type represents the type of the resource. -// (a) If this resource is read, add a new argument whose type is obtained -// from live_value_or_type, and input_index and live_value_or_type will be -// updated to reference the new argument. -// (b) If this resource is written, live_value_or_type will track the new -// value of the resource. input_index will remain to be kInputUnassigned. +// Checks if a function has only one block. +mlir::LogicalResult CheckSingleBlockFunction(FuncOp function) { + if (!hasSingleElement(function.getBlocks())) + return function.emitError() + << "expects function '" << function.getName() + << "' to have 1 block, got " << function.getBlocks().size(); + + return success(); +} + +// Collects names of users of a resource that are not `tf.ReadVariableOp` and +// not `tf.AssignVariableOp`. 
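+// The collected names are used to build the error messages emitted by the
+// validation helpers below.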
+llvm::SmallSet GetCompositeResourceUserNames( + Value resource) { + // SmallSet will use a vector when there is only one element and use std::set + // when there are more than one elements. This ensures that the operations in + // the error message are ordered. + llvm::SmallSet composite_users; + for (Operation* user : resource.getUsers()) + if (!llvm::isa(user) && + !llvm::isa(user)) + composite_users.insert(user->getName().getStringRef()); + + return composite_users; +} + +// Checks if `tf.VarHandleOp` has a valid resource subtype and its users are of +// `tf.ReadVariableOp` and `tf.AssignVariableOp` only. +mlir::LogicalResult ValidateVarHandle(TF::VarHandleOp var_handle_op) { + auto resource_type = + getElementTypeOrSelf(var_handle_op.getType()).cast(); + if (resource_type.getSubtypes().size() != 1) + return var_handle_op.emitOpError() + << "expects resource type to have one subtype, got " + << resource_type; + + auto composite_ops = GetCompositeResourceUserNames(var_handle_op); + if (!composite_ops.empty()) + return var_handle_op.emitOpError() + << "expects users to be 'tf.ReadVariableOp' or " + "'tf.AssignVariableOp', got [" + << llvm::join(composite_ops.begin(), composite_ops.end(), ", ") + << "]"; + + return success(); +} + +// Checks if resource argument has a valid resource subtype and its users are of +// `tf.ReadVariableOp` and `tf.AssignVariableOp` only. +mlir::LogicalResult ValidateResourceArgument(FuncOp function, + BlockArgument resource_arg, + TF::ResourceType resource_type) { + if (resource_type.getSubtypes().size() != 1) + return function.emitError() + << "expects resource type of argument " + << resource_arg.getArgNumber() << " to have one subtype, got " + << resource_type; + + auto composite_ops = GetCompositeResourceUserNames(resource_arg); + if (!composite_ops.empty()) + return function.emitError() + << "expects users of resource argument " + << resource_arg.getArgNumber() + << " to be 'tf.ReadVariableOp' or 'tf.AssignVariableOp', got [" + << llvm::join(composite_ops.begin(), composite_ops.end(), ", ") + << "]"; + + return success(); +} + +// Adds resource arguments for every unique (name) variable handle. Associated +// `tf.VarHandleOp` are removed from the function. Variable shared names are +// returned in `var_handle_shared_names` based on the ordering of added resource +// arguments. 
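+// For example (schematically), two "tf.VarHandleOp"s sharing shared_name "v"
+// plus one with shared_name "w" result in two new resource arguments, one per
+// unique name, with all handle uses rewired to those arguments.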
+mlir::LogicalResult PromoteVarHandlesToArguments( + FuncOp function, bool add_validation, + llvm::SmallVectorImpl* var_handle_shared_names) { + Block& block = function.front(); + auto func_type = function.getType(); + + auto func_arg_types = llvm::to_vector<4>(func_type.getInputs()); + llvm::SmallDenseMap var_arg_index_by_name; + for (auto var_handle_op : + llvm::make_early_inc_range(block.getOps())) { + if (add_validation && failed(ValidateVarHandle(var_handle_op))) + return failure(); + + llvm::StringRef name = var_handle_op.shared_nameAttr().getValue(); + auto it = var_arg_index_by_name.insert({name, func_arg_types.size()}); + if (it.second) { + var_handle_shared_names->emplace_back(name); + auto resource_type = var_handle_op.resource().getType(); + func_arg_types.push_back(resource_type); + var_handle_op.resource().replaceAllUsesWith( + block.addArgument(resource_type)); + } else { + var_handle_op.resource().replaceAllUsesWith( + block.getArgument(it.first->getSecond())); + } + var_handle_op.erase(); + } + + if (!var_handle_shared_names->empty()) + function.setType(FunctionType::get(func_arg_types, func_type.getResults(), + function.getContext())); + + return success(); +} + +// Records the current live value for a resource variable and whether a read or +// write on the variable occurred. struct ResourceInfo { - static constexpr int64_t kInputUnassigned = -1; - int64_t input_index; - llvm::PointerUnion live_value_or_type; + Value live_value = nullptr; + bool read = false; + bool write = false; }; -using ArgOrName = llvm::PointerUnion; -using ResourceMap = llvm::SmallDenseMap; - -LogicalResult PromoteResourcesToArguments(FuncOp function) { +LogicalResult PromoteResourcesToArguments( + FuncOp function, llvm::ArrayRef var_handle_shared_names) { Block& block = function.front(); auto return_op = llvm::dyn_cast_or_null(block.getTerminator()); if (!return_op) - return function.emitError( - "expects 'main' function to have a MLIR ReturnOp"); + return function.emitError() << "expects function '" << function.getName() + << "' to have a MLIR ReturnOp"; - ResourceMap resource_map; + llvm::SmallVector resources(function.getNumArguments()); auto argument_types = llvm::to_vector<4>(function.getType().getInputs()); + bool has_resources = false; + auto add_resource_argument = [&](BlockArgument arg, + TF::ResourceType resource_type) { + Type arg_type = resource_type.getSubtypes().front(); + arg.setType(arg_type); + resources[arg.getArgNumber()].live_value = arg; + argument_types[arg.getArgNumber()] = arg_type; + has_resources = true; + }; - // Loop through the resource arguments in the function and store a mapping - // from that argument to its index and itself as the current live value. - for (BlockArgument& func_arg : function.getArguments()) { + // Loop through the non `tf.VarHandleOp` resource arguments in the function, + // validate its uses and subtype, and store a mapping from that argument to + // itself as the current live value. 
+ auto func_args = function.getArguments().take_front( + function.getNumArguments() - var_handle_shared_names.size()); + for (BlockArgument& func_arg : func_args) { auto resource_type = getElementTypeOrSelf(func_arg.getType()).dyn_cast(); if (!resource_type) continue; - if (resource_type.getSubtypes().size() != 1) - return function.emitError() - << "expects resource type of argument " << func_arg.getArgNumber() - << " to have one subtype, got " << resource_type; + if (failed(ValidateResourceArgument(function, func_arg, resource_type))) + return failure(); - for (auto* user : func_arg.getUsers()) - if (!llvm::isa(user) && - !llvm::isa(user)) - return function.emitError() - << "expects users of resource argument " - << func_arg.getArgNumber() - << " to be 'tf.ReadVariableOp' or 'tf.AssignVariableOp'"; - - Type arg_type = resource_type.getSubtypes().front(); - func_arg.setType(arg_type); - resource_map[func_arg] = {func_arg.getArgNumber(), func_arg}; - argument_types[func_arg.getArgNumber()] = arg_type; + add_resource_argument(func_arg, resource_type); } - // Loop through the VarHandleOp in the function. When the first VarHandleOp - // for a resource variable is encountered, add an entry to the resource_map to - // record the information. Do not add a new function argument yet. - for (auto var_handle_op : block.getOps()) { - if (resource_map.count(var_handle_op.shared_nameAttr())) continue; - + // Loop through `tf.VarHandleOp` resource arguments in the function and store + // a mapping from that argument to itself as the current live value. No + // validations are necessary here as these arguments were validated prior to + // being added. + auto var_handle_args = + function.getArguments().take_back(var_handle_shared_names.size()); + for (BlockArgument& var_handle_arg : var_handle_args) { auto resource_type = - getElementTypeOrSelf(var_handle_op.getType()).cast(); - if (resource_type.getSubtypes().size() != 1) - return var_handle_op.emitOpError() - << "expects resource type to have one subtype, got " - << resource_type; - - resource_map[var_handle_op.shared_nameAttr()] = { - ResourceInfo::kInputUnassigned, resource_type.getSubtypes().front()}; + getElementTypeOrSelf(var_handle_arg.getType()).cast(); + add_resource_argument(var_handle_arg, resource_type); } - if (resource_map.empty()) return success(); + if (!has_resources) return success(); // We initially assign the argument for a resource as the live value for the // resource. We then walk through the operations in the function in their // lexical order, to update the live value for the resource when we see a // store to the resource and replace reads of the resource with uses of its - // live value. For the reads, if the resource does not have a live value yet, - // we add a new argument and use it as the live value. + // live value. for (Operation& op : llvm::make_early_inc_range(block)) { if (auto read_op = llvm::dyn_cast(&op)) { if (auto func_arg = read_op.resource().dyn_cast()) { if (func_arg.getOwner() != &block) return read_op.emitOpError(kResourceFunctionMsg); - // resource_map[func_arg] is always a Value when func_arg is a - // BlockArgument. 
- read_op.value().replaceAllUsesWith( - resource_map[func_arg].live_value_or_type.get()); - } else if (auto var_handle_op = llvm::dyn_cast( - read_op.resource().getDefiningOp())) { - ResourceInfo& info = resource_map[var_handle_op.shared_nameAttr()]; - if (auto live_value = info.live_value_or_type.dyn_cast()) { - read_op.value().replaceAllUsesWith(live_value); - } else { - auto arg_type = info.live_value_or_type.get(); - BlockArgument arg = block.addArgument(arg_type); - info.input_index = argument_types.size(); - info.live_value_or_type = arg; - argument_types.push_back(arg_type); - read_op.value().replaceAllUsesWith(arg); - } + ResourceInfo& resource_info = resources[func_arg.getArgNumber()]; + resource_info.read = true; + read_op.value().replaceAllUsesWith(resource_info.live_value); } else { return read_op.emitOpError(kInvalidResourceMsg); } @@ -191,11 +269,9 @@ LogicalResult PromoteResourcesToArguments(FuncOp function) { if (func_arg.getOwner() != &block) return write_op.emitOpError(kResourceFunctionMsg); - resource_map[func_arg].live_value_or_type = write_op.value(); - } else if (auto var_handle_op = llvm::dyn_cast( - write_op.resource().getDefiningOp())) { - resource_map[var_handle_op.shared_nameAttr()].live_value_or_type = - write_op.value(); + ResourceInfo& resource_info = resources[func_arg.getArgNumber()]; + resource_info.write = true; + resource_info.live_value = write_op.value(); } else { return read_op.emitOpError(kInvalidResourceMsg); } @@ -206,67 +282,68 @@ LogicalResult PromoteResourcesToArguments(FuncOp function) { const int64_t num_results_before = function.getNumResults(); auto return_operands = llvm::to_vector<4>(return_op.getOperands()); - return_operands.reserve(num_results_before + resource_map.size()); auto result_types = llvm::to_vector<4>(return_op.getOperandTypes()); - result_types.reserve(num_results_before + resource_map.size()); - llvm::SmallVector, 4> output_only_resources; - output_only_resources.reserve(resource_map.size()); + llvm::SmallVector, 4> + output_only_resources; llvm::SmallVector, 4> input_output_alias; - input_output_alias.reserve(resource_map.size()); - // Collect new return values and either (a) output-only resource attributes - // (if the resource is not promoted to an argument) or (b) mapping from - // resource input index to output alias (if the resource has been promoted to - // an argument). If the last live value is itself (argument), then that live - // value will not be returned as the resource is unmodified. - for (auto& resource : resource_map) { - int64_t input_index = resource.getSecond().input_index; - auto live_value = resource.getSecond().live_value_or_type.dyn_cast(); - if (input_index == ResourceInfo::kInputUnassigned) { - if (!live_value) continue; - - output_only_resources.push_back( - {return_operands.size(), resource.getFirst().dyn_cast()}); - } else { - // live_value is not nullptr because any input-assigned resource has a - // Value as live_value. - auto live_arg = live_value.dyn_cast(); - if (live_arg && live_arg.getOwner() == &block && - live_arg.getArgNumber() == input_index) - continue; - - input_output_alias.push_back({input_index, return_operands.size()}); - } - return_operands.push_back(live_value); - result_types.push_back(live_value.getType()); - } - - // Erase all VarHandleOp. 
- for (Operation& op : llvm::make_early_inc_range(function.front())) { - auto var_handle_op = llvm::dyn_cast(op); - if (!var_handle_op) continue; - if (!var_handle_op.use_empty()) { - // SmallSet will use a vector when there is only one element and use - // std::set when there are more than one elements. This ensures that - // the operations in the error message are ordered. - llvm::SmallSet unique_operations; - llvm::for_each( - var_handle_op.getOperation()->getUsers(), [&](Operation* user) { - unique_operations.insert(user->getName().getStringRef().str()); - }); - - return var_handle_op.emitOpError( - "expects no uses but used by operations: ") - << llvm::join(unique_operations.begin(), unique_operations.end(), - ", "); - } - - op.erase(); - } - - // Rewrite return if more results need to be returned by the function. + // Collect new return values for variable writes and either (a) output-only + // resource attributes (if the resource is not promoted to an argument) or (b) + // mapping from resource input index to output alias (if the resource has been + // promoted to an argument). Resource arguments that were originally + // `tf.VarHandleOp` but not read are collected and then removed. OpBuilder builder(return_op); - if (!output_only_resources.empty() || !input_output_alias.empty()) { + const int var_handles_start_idx = + function.getNumArguments() - var_handle_shared_names.size(); + int new_argument_index = 0; + llvm::SmallVector argument_indices_to_remove; + for (auto resource_and_index : llvm::enumerate(resources)) { + const auto& resource = resource_and_index.value(); + if (!resource.live_value) { + // Ignore non resource arguments. + ++new_argument_index; + continue; + } + + const auto index = resource_and_index.index(); + const bool is_var_handle = index >= var_handles_start_idx; + if (resource.write) { + if (!is_var_handle || resource.read) { + input_output_alias.push_back( + {new_argument_index, return_operands.size()}); + } else if (is_var_handle) { + output_only_resources.push_back( + {return_operands.size(), + var_handle_shared_names[index - var_handles_start_idx]}); + } + return_operands.push_back(resource.live_value); + result_types.push_back(resource.live_value.getType()); + } + + if (is_var_handle && !resource.read) { + assert(block.getArgument(index).getUses().empty()); + argument_indices_to_remove.push_back(index); + } else { + if (is_var_handle) { + // Add resource_name attribute to VarHandleOp read. + function.setArgAttr( + new_argument_index, kResourceNameArgAttr, + builder.getStringAttr( + var_handle_shared_names[index - var_handles_start_idx])); + } + ++new_argument_index; + } + } + + // Remove unread var handle arguments. + for (int argument_index_to_remove : + llvm::reverse(argument_indices_to_remove)) { + block.eraseArgument(argument_index_to_remove); + argument_types.erase(argument_types.begin() + argument_index_to_remove); + } + + // Rewrite return if there are variable writes. + if (return_operands.size() > num_results_before) { builder.create(return_op.getLoc(), return_operands); return_op.erase(); } @@ -274,17 +351,10 @@ LogicalResult PromoteResourcesToArguments(FuncOp function) { // Update function argument and result types with new resource subtypes. function.setType(builder.getFunctionType(argument_types, result_types)); - // Add resource_name attribute to the input argument for the resources. 
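+  // The net effect for promoted `tf.VarHandleOp`s, as a rough sketch
+  // (attribute spelling per kResourceNameArgAttr; variable name "v" is
+  // illustrative):
+  //
+  //   read & written : argument kept with {tf.resource_name = "v"}; the final
+  //                    value is returned and recorded as aliasing the argument
+  //   read only      : argument kept with {tf.resource_name = "v"}; nothing
+  //                    extra is returned
+  //   written only   : argument erased; the final value is returned as an
+  //                    output-only result tagged with the shared name
+  //   never used     : argument erased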
- for (auto& resource : resource_map) { - if (auto attr = resource.getFirst().dyn_cast()) { - int64_t input_index = resource.getSecond().input_index; - if (input_index != ResourceInfo::kInputUnassigned) - function.setArgAttr(input_index, "tf.resource_name", attr); - } - } // Add resource_name attribute to the output for the resources. for (auto& resource : output_only_resources) - function.setResultAttr(resource.first, "tf.resource_name", resource.second); + function.setResultAttr(resource.first, kResourceNameArgAttr, + builder.getStringAttr(resource.second)); // Add aliasing_output attribute to the input argument for the resources that // are updated by the function. @@ -309,26 +379,60 @@ void PromoteResourcesToArgsPass::runOnOperation() { // This routine should only be called when control flow operations are still // represented with TF IfOp and WhileOp operations. In this case, there should // be only one basic blocks in the MLIR representation. - if (!hasSingleElement(main_func.getBlocks())) { - main_func.emitError() << "expects 'main' function to have 1 block, got " - << main_func.getBlocks().size(); - return signalPassFailure(); - } + if (failed(CheckSingleBlockFunction(main_func))) return signalPassFailure(); + llvm::SmallVector var_handle_shared_names; if (failed(ResourceLiftingForFunctionalControlFlow(main_func)) || - failed(PromoteResourcesToArguments(main_func))) + failed(PromoteVarHandlesToArguments(main_func, /*add_validation=*/true, + &var_handle_shared_names)) || + failed(PromoteResourcesToArguments(main_func, var_handle_shared_names))) return signalPassFailure(); } +class PromoteVarHandlesToArgsPass + : public PassWrapper> { + public: + void runOnOperation() override; +}; + +void PromoteVarHandlesToArgsPass::runOnOperation() { + ModuleOp module = getOperation(); + MLIRContext* context = module.getContext(); + for (auto function : module.getOps()) { + if (failed(CheckSingleBlockFunction(function))) return signalPassFailure(); + + llvm::SmallVector var_handle_shared_names; + PromoteVarHandlesToArguments(function, /*add_validation=*/false, + &var_handle_shared_names); + + // Add resource names for each `tf.VarHandleOp` that were promoted to + // resource arguments. + const int var_handle_args_offset = + function.getNumArguments() - var_handle_shared_names.size(); + for (auto var_name_and_index : llvm::enumerate(var_handle_shared_names)) + function.setArgAttr(var_name_and_index.index() + var_handle_args_offset, + kResourceNameArgAttr, + StringAttr::get(var_name_and_index.value(), context)); + } +} + } // namespace std::unique_ptr> CreatePromoteResourcesToArgsPass() { return std::make_unique(); } +std::unique_ptr> CreatePromoteVarHandlesToArgsPass() { + return std::make_unique(); +} + static PassRegistration pass( "tf-promote-resources-to-args", "Promote resources reads/writes to function inputs/outputs."); +static PassRegistration var_handle_pass( + "tf-promote-var-handles-to-args", + "Promote tf.VarHandleOps to function arguments."); + } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc b/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc new file mode 100644 index 00000000000..a80b84ddeda --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc @@ -0,0 +1,179 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" + +namespace mlir { +namespace TF { +namespace { + +// Location attribute. +constexpr StringRef kClassAttr = "_class"; +constexpr StringRef kLocationPrefix = "loc:@"; + +// A pass that converts readonly reference variables to the corresponding +// resource variables. +// +// It converts (VariableV2 -> Identity) to (VarHandle -> ReadVariable). +// +// For the background, this pass is a part of hoisting VariableV2 ops by +// re-using the pipeline for hoisting (VarHandle -> ReadVariable) cases, which +// can be done by the following passes: +// - Capturing resource values into global tensors (importing saved model). +// - Promoting VarHandle ops to function input/outputs. +// - Freezing global tensor pass. +// +// This path assumes that all the VariableV2 ops is read-only via verifying the +// heuristic method that assumes that all the users of them is Identity op, +// fed directly. +class ConvertReadonlyReferenceVariablesToResourceVariablesPass + : public PassWrapper< + ConvertReadonlyReferenceVariablesToResourceVariablesPass, + FunctionPass> { + public: + void runOnFunction() override; +}; + +// Parse node name from "_class" attribute. +StringRef GetNodeNameFromClassAttr(Operation *op) { + ArrayAttr classes_attr = op->getAttrOfType(kClassAttr); + if (!classes_attr) { + op->emitOpError() << "has no '_class' attribute"; + return StringRef(); + } + + StringRef result; + for (Attribute class_attr : classes_attr) { + StringRef node_name = class_attr.cast().getValue(); + if (!node_name.startswith(kLocationPrefix)) { + continue; + } + if (!result.empty()) { + // Invalid case since there are multiple loc:@ attributes. 
+ op->emitOpError() + << "expects only one named location in '_class' attribute, but got " + << classes_attr; + return StringRef(); + } + result = node_name.drop_front(kLocationPrefix.size()); + } + if (result.empty()) { + op->emitOpError() << "expects variable name in '_class' attribute, but got " + << classes_attr; + } + return result; +} + +void ConvertReadonlyReferenceVariablesToResourceVariablesPass::runOnFunction() { + FuncOp func = getFunction(); + + OpBuilder builder(func.getContext()); + SmallVector variable_v2s_to_replace; + + // Checks all the VariableV2 ops is read-only via verifying the heuristic + // method that assumes that all the users of them is Identity op, feeded + // directly. + auto read_only_vars_fn = [&variable_v2s_to_replace]( + VariableV2Op variable_v2_op) { + if (variable_v2_op.getResult().use_empty()) { + // Erase the op when there is no user. + variable_v2_op.erase(); + return mlir::WalkResult::advance(); + } + if (!all_of(variable_v2_op.getResult().getUsers(), [&variable_v2_op]( + Operation *user) { + if (!isa(user)) { + variable_v2_op.emitOpError() + << "expects all users to be 'tf.Identity', but got user " + << user->getName(); + return false; + } + return true; + })) { + return mlir::WalkResult::interrupt(); + } + variable_v2s_to_replace.push_back(variable_v2_op); + return mlir::WalkResult::advance(); + }; + + WalkResult walk_res = func.walk(read_only_vars_fn); + if (walk_res.wasInterrupted()) return signalPassFailure(); + + for (VariableV2Op variable_v2_op : variable_v2s_to_replace) { + builder.setInsertionPoint(variable_v2_op); + ShapedType shaped_type = + variable_v2_op.getResult().getType().cast(); + TensorType tensor_type = DropRefType(shaped_type).cast(); + StringAttr device_attr = variable_v2_op.getAttrOfType("device"); + if (!device_attr) device_attr = builder.getStringAttr(""); + StringRef variable_name = GetNodeNameFromClassAttr(variable_v2_op); + if (variable_name.empty()) { + return signalPassFailure(); + } + VarHandleOp var_handle_op = builder.create( + variable_v2_op.getLoc(), + ArrayRef{RankedTensorType::get( + {}, TF::ResourceType::get(ArrayRef{tensor_type}, + builder.getContext()))}, + ArrayRef{}, + ArrayRef{ + builder.getNamedAttr("device", device_attr), + builder.getNamedAttr("container", variable_v2_op.containerAttr()), + builder.getNamedAttr("shared_name", + builder.getStringAttr(variable_name))}); + for (Operation *user : + make_early_inc_range(variable_v2_op.getResult().getUsers())) { + builder.setInsertionPoint(user); + ReadVariableOp read_variable_op = builder.create( + user->getLoc(), ArrayRef{tensor_type}, + ArrayRef{var_handle_op}, ArrayRef{}); + user->getResult(0).replaceAllUsesWith(read_variable_op.getResult()); + user->erase(); + } + variable_v2_op.erase(); + } +} + +} // namespace + +std::unique_ptr> +CreateConvertReadonlyReferenceVariablesToResourceVariablesPass() { + return std::make_unique< + ConvertReadonlyReferenceVariablesToResourceVariablesPass>(); +} + +static PassRegistration< + ConvertReadonlyReferenceVariablesToResourceVariablesPass> + pass("readonly-references-to-resources", + "Convert readonly reference variables to resource variables."); + +} // namespace TF + +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc index fe9283d6932..15eb5593651 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc +++ 
b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc @@ -37,18 +37,37 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/core/platform/logging.h" namespace mlir { namespace TFDevice { namespace { constexpr char kDeviceAttr[] = "device"; +constexpr char kReplicaIdAttr[] = "_xla_replica_id"; struct ReplicateToIslandPass : public PassWrapper { void runOnFunction() override; }; +// Returns whether op requires `_xla_replica_id` attribute. +bool RequiresReplicaIDAttribute(Operation* op) { + return llvm::isa(op) || + llvm::isa(op); +} + +// Adds integer attribute that represents replica id for replicated ops that +// require replica id attribute. +void AddReplicaIdToOpsInReplicatedRegion(OpBuilder* builder, Region* region, + const int replica_id) { + region->walk([&](Operation* replicated_op) { + if (RequiresReplicaIDAttribute(replicated_op)) + replicated_op->setAttr(kReplicaIdAttr, + builder->getI32IntegerAttr(replica_id)); + }); +} + // Creates islands per replica from `tf_device.replicate` region. If for a // `tf_device.launch` op the device is an aliased device of the // `tf_device.replicate`, the device will be remapped to an explicit device @@ -90,6 +109,14 @@ llvm::SmallVector ExpandReplicateIntoReplicas( // Copy over replicate region into replica island. replicate_op.body().cloneInto(&replica.body(), mapping); + // TODO(b/157624749): Replace this with better abstraction to + // differentiate ops for different replicas. + // Some ops, such as XlaHostCompute op or TPU Embedding ops, require + // replica id to be added as an op attribute to be used during + // execution. Handle such ops separately and add an integer attribute + // that represents replica id. + AddReplicaIdToOpsInReplicatedRegion(builder, &replica.body(), i); + // Map aliased devices to explicit devices based on replica. if (has_devices) { replica.walk([&](tf_device::LaunchOp launch) { @@ -156,9 +183,9 @@ llvm::SmallVector ExpandReplicateIntoReplicas( // }) {device = "/DEVICE:3"} : () -> tensor // tf_executor.yield %a1, %b1 : tensor, tensor // } -LogicalResult CreateIslandsFromReplicate(const Dialect* tf_dialect, - tf_executor::IslandOp island_op, - tf_device::ReplicateOp replicate_op) { +void CreateIslandsFromReplicate(const Dialect* tf_dialect, + tf_executor::IslandOp island_op, + tf_device::ReplicateOp replicate_op) { OpBuilder builder(island_op); const int num_replicas = replicate_op.n().getLimitedValue(); @@ -199,21 +226,17 @@ LogicalResult CreateIslandsFromReplicate(const Dialect* tf_dialect, } island_op.erase(); - return success(); } // Finds islands with a single `tf_device.replicate` and create individual // islands per replica of the replicate. 
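+// For illustration, ops that must know their replica (e.g. XlaHostCompute or
+// the TPU embedding ops mentioned above) end up tagged per replica roughly as
+//
+//   "tf.SomeReplicaAwareOp"(...) {_xla_replica_id = 1 : i32}
+//
+// where the op name here is hypothetical; only `_xla_replica_id` and its i32
+// value come from AddReplicaIdToOpsInReplicatedRegion above.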
-LogicalResult LowerSingleIslandReplicateToIslands( - const Dialect* tf_dialect, tf_executor::IslandOp island_op) { - if (!hasSingleElement(island_op.GetBody().without_terminator())) - return success(); +void LowerSingleIslandReplicateToIslands(const Dialect* tf_dialect, + tf_executor::IslandOp island_op) { + if (!island_op.WrapsSingleOp()) return; if (auto replicate_op = llvm::dyn_cast(&island_op.GetBody().front())) - return CreateIslandsFromReplicate(tf_dialect, island_op, replicate_op); - - return success(); + CreateIslandsFromReplicate(tf_dialect, island_op, replicate_op); } void ReplicateToIslandPass::runOnFunction() { @@ -223,13 +246,9 @@ void ReplicateToIslandPass::runOnFunction() { getFunction().emitError() << "'tf' dialect is not registered"; } - auto result = getFunction().walk([&](tf_executor::IslandOp island_op) { - if (failed(LowerSingleIslandReplicateToIslands(tf_dialect, island_op))) - return WalkResult::interrupt(); - return WalkResult::advance(); + getFunction().walk([&](tf_executor::IslandOp island_op) { + LowerSingleIslandReplicateToIslands(tf_dialect, island_op); }); - - if (result.wasInterrupted()) return signalPassFailure(); } } // anonymous namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc index d37dfd14590..21d74d81b20 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc @@ -149,7 +149,7 @@ LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op, } auto walk_res = func_op.walk([&](Operation* op) { if (auto var_handle = llvm::dyn_cast(op)) { - // Record VarHanldeOp's device attribute. + // Record VarHandleOp's device attribute. auto device_attr = var_handle.getAttrOfType(kDeviceAttr); if (!device_attr || device_attr.getValue().empty()) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc index 611c4d2725a..ed7ebc25c9f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc @@ -571,7 +571,7 @@ void AddLoadsStoresOutsideControlFlowOp( } // Lifts loads/stores from while loop's body and cond functions. -LogicalResult HanldeWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) { +LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) { // Remove identity nodes to avoid aliasing. RemoveIdentity(&body.front()); RemoveIdentity(&cond.front()); @@ -668,72 +668,74 @@ LogicalResult HanldeWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) { return success(); } -// Lifts loads/stores from an IfOp's branches. -LogicalResult HandleIfOP(TF::IfOp if_op, FuncOp then_branch, - FuncOp else_branch) { +// Lifts loads/stores from an IfOp or CaseOp's branches. +template +LogicalResult HandleCaseOrIfOp(CaseOrIfOp op, ArrayRef branches) { // Remove identity nodes to avoid aliasing. - RemoveIdentity(&then_branch.front()); - RemoveIdentity(&else_branch.front()); + for (auto func : branches) RemoveIdentity(&func.front()); + // Sanity check: branch return of resources should be aliases of inputs. If // so, replace the output uses with the input so that we can remove these // outputs. 
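+  // For example (a simplified sketch): if every branch returns its own
+  // resource argument %r at the same result position,
+  //
+  //   func @then_branch(..., %r: tensor<!tf.resource<tensor<f32>>>) {
+  //     ...
+  //     return ..., %r
+  //   }
+  //
+  // that resource result of the tf.If/tf.Case is replaced with the matching
+  // operand of the op (offset by one to skip the predicate/branch index) and
+  // can then be dropped from the branch signatures.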
- for (auto entry : llvm::enumerate( - llvm::zip(then_branch.front().getTerminator()->getOperands(), - else_branch.front().getTerminator()->getOperands()))) { - auto then_retval = std::get<0>(entry.value()); - auto else_retval = std::get<1>(entry.value()); - assert(then_retval.getType() == else_retval.getType()); - if (!getElementTypeOrSelf(then_retval.getType()).isa()) { + for (OpResult result : op.getResults()) { + if (!getElementTypeOrSelf(result.getType()).isa()) continue; + unsigned result_index = result.getResultNumber(); + constexpr unsigned kUnassigned = -1; + unsigned common_aliasing_arg_num = kUnassigned; + for (auto func : branches) { + auto retval = func.front().getTerminator()->getOperand(result_index); + assert(result.getType() == retval.getType()); + auto aliasing_arg = retval.dyn_cast(); + if (common_aliasing_arg_num == kUnassigned) + common_aliasing_arg_num = aliasing_arg.getArgNumber(); + if (!aliasing_arg || + aliasing_arg.getArgNumber() != common_aliasing_arg_num) + return op.emitOpError("unsupported output: ") + << "resource does not alias a single input"; } - auto then_aliasing_arg = then_retval.dyn_cast(); - auto else_aliasing_arg = else_retval.dyn_cast(); - if (!then_aliasing_arg || !else_aliasing_arg || - then_aliasing_arg.getArgNumber() != else_aliasing_arg.getArgNumber()) { - return if_op.emitOpError("unsupported tf.IfOp output: ") - << "resource does not alias a single input."; - } - if_op.getResult(entry.index()) - .replaceAllUsesWith( - if_op.getOperand(then_aliasing_arg.getArgNumber() + 1)); + assert(common_aliasing_arg_num != kUnassigned); + result.replaceAllUsesWith(op.getOperand(common_aliasing_arg_num + 1)); } + // Erase the resource outputs from the branches. int64_t non_resource_results = 0; llvm::SmallVector old_to_new_output_indices; - llvm::SmallVector new_output_shapes; bool output_removed = false; - for (auto result : if_op.getResults()) { - if (!getElementTypeOrSelf(result.getType()).isa()) { + for (auto result : op.getResults()) { + if (!getElementTypeOrSelf(result.getType()) + .template isa()) { old_to_new_output_indices.push_back(non_resource_results++); - if (!if_op.output_shapes().getValue().empty()) { - new_output_shapes.push_back( - if_op.output_shapes().getValue()[result.getResultNumber()]); - } continue; } old_to_new_output_indices.push_back(-1); - then_branch.front().getTerminator()->eraseOperand(non_resource_results); - else_branch.front().getTerminator()->eraseOperand(non_resource_results); + for (auto func : branches) + func.front().getTerminator()->eraseOperand(non_resource_results); output_removed = true; } - llvm::SmallDenseMap then_use_info; - llvm::SmallDenseMap else_use_info; - if (failed(FindResourceArgUseInfo(then_branch, &then_use_info)) || - failed(FindResourceArgUseInfo(else_branch, &else_use_info))) { + llvm::SmallDenseMap resource_arg_uses; + if (failed(FindResourceArgUseInfo(branches.front(), &resource_arg_uses))) return failure(); + + for (auto func : branches.drop_front()) { + llvm::SmallDenseMap branch_use_info; + if (failed(FindResourceArgUseInfo(func, &branch_use_info))) + return failure(); + // A resource is considered used as long as it is used in either branch. + resource_arg_uses = + MergeArgResourceUseInfo(resource_arg_uses, branch_use_info); } - // A resource is considered used as long as it is used in either branch. - auto resource_arg_uses = - MergeArgResourceUseInfo(then_use_info, else_use_info); + if (resource_arg_uses.empty() && !output_removed) return success(); // Remove unused resources in functions. 
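+  // For example: if only one branch reads a given resource argument, the
+  // merged info above still marks it as read, so the argument survives in
+  // every branch (keeping the signatures identical); an argument no branch
+  // reads or writes is removed from all of them below.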
llvm::SmallDenseMap remaining_resource_data_types; RemoveUnusedResourceArgumentsAndForwardedRetvals( - resource_arg_uses, then_branch, /*old_to_new_arg_indices=*/nullptr, + resource_arg_uses, branches.front(), /*old_to_new_arg_indices=*/nullptr, &remaining_resource_data_types); - RemoveUnusedResourceArgumentsAndForwardedRetvals(resource_arg_uses, - else_branch); + for (auto func : branches.drop_front()) + RemoveUnusedResourceArgumentsAndForwardedRetvals(resource_arg_uses, func); + // Forward resource inputs updated in any branch to the outputs of both // branches. First prepare the mapping from arg to new update output. llvm::SmallDenseMap resource_arg_to_new_output; @@ -751,10 +753,11 @@ LogicalResult HandleIfOP(TF::IfOp if_op, FuncOp then_branch, new_output_index; } } + // Append resource updates to the return ops: now they are just forwarded // input resources, but will be replaced by the data value in // LiftArgRetResourcesForFunction(). - for (auto branch : {then_branch, else_branch}) { + for (auto branch : branches) { auto new_retvals = llvm::to_vector<4>(branch.front().getTerminator()->getOperands()); for (const auto& entry : resource_arg_to_new_output) { @@ -771,18 +774,18 @@ LogicalResult HandleIfOP(TF::IfOp if_op, FuncOp then_branch, }); } - // Recreate the if op. - OpBuilder builder(if_op); + // Recreate the op without resource operands. + OpBuilder builder(op); // Now use the filtered original operands, which will be replaced by // AddLoadsStoresOutsideControlFlowOp(). auto new_operands = - FilterRange(if_op.input(), resource_arg_uses); - new_operands.insert(new_operands.begin(), if_op.cond()); - auto new_if = builder.create(if_op.getLoc(), - then_branch.getType().getResults(), - new_operands, if_op.getAttrs()); - // Prepare for AddLoadsStoresOutsideControlFlowOp() and update - // new_output_shapes. + FilterRange(op.input(), resource_arg_uses); + new_operands.insert(new_operands.begin(), op.getOperand(0)); + FuncOp first_func = branches.front(); + auto new_op = + builder.create(op.getLoc(), first_func.getType().getResults(), + new_operands, op.getAttrs()); + // Prepare for AddLoadsStoresOutsideControlFlowOp() llvm::SmallDenseMap> arg_data_type_and_updated_output_index; for (const auto& entry : remaining_resource_data_types) { @@ -792,22 +795,17 @@ LogicalResult HandleIfOP(TF::IfOp if_op, FuncOp then_branch, : new_output_it->getSecond(); arg_data_type_and_updated_output_index[entry.getFirst() + 1] = { entry.getSecond(), update_index}; - if (!if_op.output_shapes().getValue().empty() && update_index >= 0) { - new_output_shapes.push_back( - tensorflow::ConvertTypeToTensorShapeAttr(entry.getSecond())); - } } - AddLoadsStoresOutsideControlFlowOp(new_if, + AddLoadsStoresOutsideControlFlowOp(new_op, arg_data_type_and_updated_output_index); - new_if.setAttr("output_shapes", builder.getArrayAttr(new_output_shapes)); // Replace uses. 
for (int64_t i = 0; i < old_to_new_output_indices.size(); ++i) { if (old_to_new_output_indices[i] >= 0) { - if_op.getResult(i).replaceAllUsesWith( - new_if.getResult(old_to_new_output_indices[i])); + op.getResult(i).replaceAllUsesWith( + new_op.getResult(old_to_new_output_indices[i])); } } - if_op.erase(); + op.erase(); return success(); } @@ -985,7 +983,7 @@ LogicalResult HoistForFunctionalControlFlow( lifted_partitioned_call_callees); HoistForFunctionalControlFlow(&cond.front(), module, lifted_partitioned_call_callees); - if (failed(HanldeWhileLoop(while_op, body, cond))) return failure(); + if (failed(HandleWhileLoop(while_op, body, cond))) return failure(); } else if (auto if_op = llvm::dyn_cast(&op)) { auto then_branch = llvm::cast(module.lookupSymbol(if_op.then_branch())); @@ -996,7 +994,20 @@ LogicalResult HoistForFunctionalControlFlow( lifted_partitioned_call_callees); HoistForFunctionalControlFlow(&else_branch.front(), module, lifted_partitioned_call_callees); - if (failed(HandleIfOP(if_op, then_branch, else_branch))) return failure(); + if (failed(HandleCaseOrIfOp(if_op, {then_branch, else_branch}))) + return failure(); + } else if (auto case_op = llvm::dyn_cast(&op)) { + SmallVector branch_functions; + branch_functions.reserve(case_op.branches().size()); + for (const Attribute& branch : case_op.branches()) { + FuncOp func = + module.lookupSymbol(branch.cast()); + // Recursively handle the nested control flow. + HoistForFunctionalControlFlow(&func.front(), module, + lifted_partitioned_call_callees); + branch_functions.push_back(func); + } + if (failed(HandleCaseOrIfOp(case_op, branch_functions))) return failure(); } else if (auto call_op = llvm::dyn_cast(&op)) { if (!call_op.f().isa()) { return call_op.emitOpError( diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index 789088bd585..1e9be76aa66 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -66,8 +66,7 @@ using tensorflow::shape_inference::ShapeHandle; namespace mlir { namespace TF { namespace { -Optional> InferShapeForFunctionReturnType( - FuncOp func) { +Optional> InferShapeForFunctionReturnType(FuncOp func) { // Find any return ops. SmallVector return_ops; for (Block& block : func) { @@ -137,9 +136,9 @@ void AddCastBackForUnsupportedNonTFUses(Operation* op, Value result, cast_op = b.create(op->getLoc(), old_type, result, /*truncate=*/b.getBoolAttr(false)); } - return mlir::Value(cast_op); + return Value(cast_op); }; - for (OpOperand& use : llvm::make_early_inc_range(result.getUses())) { + for (OpOperand& use : make_early_inc_range(result.getUses())) { if (use.getOwner()->getDialect() != tf_dialect && !IsSupportedNonTFOp(use.getOwner())) use.set(get_cast_op()); @@ -162,7 +161,7 @@ Optional GetShapeFromMlirType(Type t) { bool InferShapeForPassThroughOps(OperandRange pass_through_operands, Operation* op, Dialect* tf_dialect) { bool changed = false; - for (auto entry : llvm::zip(pass_through_operands, op->getResults())) { + for (auto entry : zip(pass_through_operands, op->getResults())) { Type operand_type = std::get<0>(entry).getType(); Value result = std::get<1>(entry); if (result.getType() == operand_type) continue; @@ -204,7 +203,7 @@ bool InferShapeForNonTFDialectOperation(Operation* op, Dialect* tf_dialect) { tf_dialect); } // TODO(b/155227679): Use OpInterface instead of hard-coding for TensorCastOp. 
- if (auto tensor_cast = dyn_cast(op)) { + if (auto tensor_cast = dyn_cast(op)) { return InferShapeForPassThroughOps( tensor_cast.getOperation()->getOperands(), op, tf_dialect); } @@ -254,7 +253,7 @@ GetSubtypes(Type type) { // match the i-th operand type). Returns true if anything is changed. bool PassThroughOperandTypes(OperandRange operands, ResultRange results) { bool changed = false; - for (auto entry : llvm::zip(operands, results)) { + for (auto entry : zip(operands, results)) { Type operand_type = std::get<0>(entry).getType(); Type result_type = std::get<1>(entry).getType(); if (operand_type == result_type) continue; @@ -291,14 +290,13 @@ bool InferShapeForCall(Operation* op) { CallInterfaceCallable callable = call_op.getCallableForCallee(); SymbolRefAttr sym = callable.dyn_cast(); if (!sym) return false; - FuncOp func = - dyn_cast(SymbolTable::lookupNearestSymbolFrom(op, sym)); + FuncOp func = dyn_cast(SymbolTable::lookupNearestSymbolFrom(op, sym)); if (!func) return false; bool changed = false; // Map each of the results of the call to the returned type of the // function. - for (auto result : llvm::zip(op->getResults(), func.getType().getResults())) { + for (auto result : zip(op->getResults(), func.getType().getResults())) { if (std::get<0>(result).getType() == std::get<1>(result)) continue; // Skip already statically shaped results. if (!CanBeRefined(std::get<0>(result).getType())) continue; @@ -323,8 +321,8 @@ bool RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti, Operation* op = infer_ti.getOperation(); SmallVector inferred; LogicalResult res = infer_ti.inferReturnTypes( - op->getContext(), op->getLoc(), op->getOperands(), op->getAttrs(), - op->getRegions(), inferred); + op->getContext(), op->getLoc(), op->getOperands(), + op->getAttrDictionary(), op->getRegions(), inferred); if (failed(res)) { op->emitOpError("failed to refine type as inference failed"); return false; @@ -335,7 +333,7 @@ bool RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti, // Map each of the results of the call to the returned type of the // function. bool changed = false; - for (auto result : llvm::zip(op->getResults(), inferred)) { + for (auto result : zip(op->getResults(), inferred)) { if (std::get<0>(result).getType() == std::get<1>(result)) continue; // Inserts a cast back to the original type if any user is not in the @@ -356,7 +354,7 @@ bool RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti, // so for tf.Const -> tensor<10x20xf32>, [0,2,18] would point to a unique output // scalar value). 
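+// For illustration, a sketch of how the struct below is populated (see its
+// constructors):
+//
+//   ValuePort(op_result) -> producer = the defining op, port = {result #}
+//   ValuePort(block_arg) -> producer = the block argument, port = {0}
+//
+// Deeper ports are used when only part of a produced value (e.g. one element
+// of a shape tensor) needs to be evaluated.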
struct ValuePort { - llvm::PointerUnion producer; + PointerUnion producer; SmallVector port; bool operator==(const ValuePort& other) const { @@ -374,39 +372,38 @@ struct ValuePort { port = {0}; } } - ValuePort(llvm::PointerUnion producer, + ValuePort(PointerUnion producer, SmallVector port) : producer(producer), port(port) {} - llvm::raw_ostream& print(llvm::raw_ostream& os) const { + raw_ostream& print(raw_ostream& os) const { if (auto* op = producer.dyn_cast()) os << "op " << op->getName(); if (auto ba = producer.dyn_cast()) os << "block_arg " << ba.getArgNumber(); - os << llvm::formatv(" [{0}]", llvm::make_range(port.begin(), port.end())); + os << formatv(" [{0}]", llvm::make_range(port.begin(), port.end())); return os; } }; struct ValuePortHasher { std::size_t operator()(const ValuePort& other) const { - return llvm::hash_combine( - llvm::hash_value(other.producer.getOpaqueValue()), - llvm::hash_value(ArrayRef(other.port))); + return hash_combine(llvm::hash_value(other.producer.getOpaqueValue()), + hash_value(ArrayRef(other.port))); } }; using ValuePortResultMap = std::unordered_map; -using ComputedQueryFn = llvm::function_ref; -using ValueQueryFn = llvm::function_ref; -using ValuePortInputs = llvm::SmallVectorImpl; +using ComputedQueryFn = function_ref; +using ValueQueryFn = function_ref; +using ValuePortInputs = SmallVectorImpl; -// TODO(jpienaar): InputsRequiredForOutput and ComputeOutputComponent are +// TODO(jpienaar): ComputeInputsRequiredForOutput and ComputeOutputComponent are // intended to be switched to op interfaces once more refined. -LogicalResult InputsRequiredForOutput(ValuePort value_port, - ComputedQueryFn has_been_computed, - ValuePortInputs* inputs) { +LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port, + ComputedQueryFn has_been_computed, + ValuePortInputs* inputs) { auto op = value_port.producer.dyn_cast(); auto& port = value_port.port; if (!op) return failure(); @@ -432,7 +429,8 @@ LogicalResult InputsRequiredForOutput(ValuePort value_port, // existing computed values. Attribute ComputeOutputComponent(const ValuePort& value_port, ValueQueryFn values) { - LLVM_DEBUG(value_port.print(llvm::errs() << "\nComputing output for ")); + LLVM_DEBUG(value_port.print(llvm::dbgs() << "Computing output for ") << "\n"); + if (auto known = values(value_port)) return known; auto op = value_port.producer.dyn_cast(); if (!op) return nullptr; @@ -457,29 +455,140 @@ Attribute ComputeOutputComponent(const ValuePort& value_port, ValuePort op_port(op->getOperand(port[1])); return values(op_port); } + + if (auto graph = dyn_cast(op)) { + if (port.size() == 1) + return ComputeOutputComponent( + ValuePort(graph.GetFetch().fetches()[port[0]]), values); + return nullptr; + } + + if (auto island = dyn_cast(op)) { + if (port.size() == 1) + return ComputeOutputComponent( + ValuePort(island.GetYield().fetches()[port[0]]), values); + return nullptr; + } + return nullptr; } -ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic) { +// Context used during ShapeInference. This class contains common information +// that is required by the individual shape inference helper functions (e.g., +// TF Graph version, constant values computed, etc.) 
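+// A minimal usage sketch, mirroring InferShapeForFunction further below (the
+// flag value here is illustrative):
+//
+//   ShapeInference context(graph_version, func.getContext(),
+//                          /*propagate_caller_callee_constants=*/false);
+//   if (failed(context.InferShapeUntilFixPoint(&func.getBody())))
+//     return failure();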
+class ShapeInference { + public: + ShapeInference(int64_t graph_version, MLIRContext* context, + bool propagate_caller_callee_constants); + + LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port, + ValuePortInputs* inputs) { + return ::mlir::TF::ComputeInputsRequiredForOutput( + value_port, + [this](const ValuePort& port) { + return results_.find(port) != results_.end(); + }, + inputs); + } + + Attribute ComputeOutputComponent(const ValuePort& value_port) { + if (auto known_attr = results_[value_port]) return known_attr; + auto attr = ::mlir::TF::ComputeOutputComponent( + value_port, [this](const ValuePort& port) { return results_[port]; }); + RecordValue(value_port, attr); + return attr; + } + + // Returns ShapeHandle if the op result could be computed as shape. + ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic); + + void RecordValue(const ValuePort& value_port, Attribute value) { + LLVM_DEBUG(value_port.print(llvm::dbgs() << "\trecording ") + << value << "\n"); + results_[value_port] = value; + } + + // Performs shape inference on the provided op and return true if the type of + // at least one result has been changed. + // A tf.Cast() is inserted for any uses that isn't in the TensorFlow dialect. + // `graph_version` indicates the current GraphDef compatibility versions + // (the versions field in graph.proto). + bool InferShapeForSingleOperation(Operation* op); + + // Infers shape on the provided region, including nested ones, iterate until + // fix point with a limit of max_iteration. Returns success if fix point is + // reached before max_iteration. + LogicalResult InferShapeUntilFixPoint(Region* region, + int64_t max_iteration = 10); + + // Updates input types and refine shapes inside body of functions that are + // attached to ControlFlow ops (If/While). These functions include Then/Else + // branches of IfOp and Cond/Body functions of WhileOp. These functions share + // following common properties: + // 1) They are never reused, ie. having a single use in module. + // 2) Their input types match those of their parent ops (excluding inputs + // like predicate). + // Returns a boolean indicating whether any change has been applied. + LogicalResult RefineShapeForControlFlowFunc(FuncOp func, + ArrayRef input_types, + int64_t max_iteration); + + // Propagate the shapes to the functions named. + LogicalResult PropagateShapeToFunctions( + ModuleOp module, Operation::operand_type_range input_types, + ArrayRef func_names, int64_t max_iteration); + + // Shape propagation for call/control flow ops. + LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op, + int64_t max_iteration); + + // Propagates any constant operand of call_op to the called function body's + // corresponding argument if the callee has only one use. + // + // TODO(b/154065712): Move this to a more general inter-procedural constant + // folding pass. + void PropagateConstantToCallee(CallOpInterface call_op, + SymbolRefAttr callee_sym, ModuleOp module); + + // Propagates any constant return value of the callee function to the call + // op's corresponding result. + void PropagateConstantFromCallee(CallOpInterface call_op, + SymbolRefAttr callee_sym, ModuleOp module); + + // Tries to compute the result of folding the op. This doesn't actually + // perform constant folding, it is just computes the equivalent constants. + // Returns whether it was able to compute constant values. 
+ LogicalResult TryToFold(Operation* op); + + private: + // Mapping between ValuePort (which corresponds to an OpResult or smaller, + // e.g., first element of OpResult produced) to an Attribute if the ValuePort + // corresponds to a constant value. + ValuePortResultMap results_; + int64_t graph_version_; + Dialect* tf_dialect_; + + // TODO(b/154065712): Remove propagate_caller_callee_constants once using + // SCCP pass instead. + bool propagate_caller_callee_constants_; +}; + +ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context, + bool propagate_caller_callee_constants) + : graph_version_(graph_version), + propagate_caller_callee_constants_(propagate_caller_callee_constants) { + tf_dialect_ = context->getRegisteredDialect(); +} + +ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result, + InferenceContext* ic) { LLVM_DEBUG(result.print(llvm::dbgs() << "\nEvaluate partially ")); auto rt = result.getType().dyn_cast(); if (!rt || !rt.hasStaticShape() || rt.getRank() != 1) return {}; int dim_size = rt.getDimSize(0); // Worklist to direct partial evaluation. - llvm::SmallVector worklist; - // The ValuePort evaluated results. - // TODO(jpienaar): This could be cached across invocations (e.g., part of some - // inference context). - ValuePortResultMap evaluated; - // Returns whether a ValuePort has been previously computed. - auto has_been_computed = [&evaluated](const ValuePort& port) { - return evaluated.find(port) != evaluated.end(); - }; - // Returns previously computed ValuePort value. - auto values = [&evaluated](const ValuePort& port) -> Attribute { - return evaluated[port]; - }; + SmallVector worklist; // Simple evaluator that attempts to partially evaluate the input value even // if unable to evaluate the complete output. Below follows a simple stack @@ -498,7 +607,7 @@ ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic) { LLVM_DEBUG(front.print(llvm::errs() << "\nWorklist front ")); SmallVector inputs; - auto res = InputsRequiredForOutput(front, has_been_computed, &inputs); + auto res = ComputeInputsRequiredForOutput(front, &inputs); if (failed(res)) { // Abort if unable to find which required inputs need to be computed. worklist.clear(); @@ -513,16 +622,15 @@ ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic) { continue; } - auto ret = ComputeOutputComponent(front, values); + auto ret = ComputeOutputComponent(front); if (!ret) continue; - evaluated[front] = ret; LLVM_DEBUG(ret.print(llvm::dbgs() << "\ncomputed result = ")); // If worklist is empty, then this is the root query op. if (worklist.empty()) { LLVM_DEBUG(llvm::dbgs() << "[root node]\n"); - if (auto dea = ret.dyn_cast()) { + if (auto dea = ret.dyn_cast()) { if (dea.getNumElements() != 1) { LLVM_DEBUG(llvm::errs() << "Unexpected number of elements\n"); return {}; @@ -536,9 +644,10 @@ ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic) { return ic->MakeShape(dims); } -bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, - int64_t graph_version) { - assert(tf_dialect == op->getDialect()); +bool ShapeInference::InferShapeForSingleOperation(Operation* op) { + LLVM_DEBUG(op->print(llvm::dbgs() << "InferShapeForSingleOperation for "); + llvm::dbgs() << "\n"); + assert(tf_dialect_ == op->getDialect()); // The shape function of these ops sometimes does not propagate subtypes // (handle shapes) for resource and variant types. We use a simple passthrough // to make sure they are preserved in the output. 
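+  // Note on constant operands (a summary of the logic further down in this
+  // function): an operand is materialized as an input tensor for the shape
+  // function when it is either a literal constant (matched via m_Constant) or
+  // a value this pass has already computed and recorded through
+  // ComputeOutputComponent/RecordValue.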
@@ -550,7 +659,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, // If no result for this op needs shape inference, we have a fast-path return. // But if the type is a resource/variant, we do not skip it because we might // not have the handle shapes. - if (llvm::none_of(op->getResultTypes(), CanBeRefined)) { + if (none_of(op->getResultTypes(), CanBeRefined)) { LLVM_DEBUG(llvm::dbgs() << "Skipping inference for statically shaped op '" << op->getName() << "'.\n"); return false; @@ -565,8 +674,8 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, // This is necessary to avoid reprocessing the tf.Cast that are inserted at // the end of this function. if (isa(op) && - llvm::all_of(op->getResult(0).getUsers(), [&](Operation* user) { - return user->getDialect() != tf_dialect; + all_of(op->getResult(0).getUsers(), [&](Operation* user) { + return user->getDialect() != tf_dialect_; })) { LLVM_DEBUG(llvm::dbgs() << "Skipping inference for tf.Cast with no TF " "dialect operation users '" @@ -622,10 +731,14 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, size_t index = it.index(); // If the operand is constant, then convert it to Tensor. - ElementsAttr attr; - if (matchPattern(operand, m_Constant(&attr))) { + ValuePort vp(operand); + Attribute attr = ComputeOutputComponent(vp); + if (!attr && matchPattern(operand, m_Constant(&attr))) + RecordValue(vp, attr); + if (attr) { tensorflow::Tensor* input_tensor = &tensors[index]; - auto status = tensorflow::ConvertToTensor(attr, input_tensor); + auto status = + tensorflow::ConvertToTensor(attr.cast(), input_tensor); if (status.ok()) { input_tensors[index] = input_tensor; } else { @@ -646,7 +759,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, // Perform the shape inference using an InferenceContext with the input // shapes. This object is abstracting the information that the ShapeInference // function operates on. - InferenceContext c(graph_version, *node_def, op_reg_data->op_def, + InferenceContext c(graph_version_, *node_def, op_reg_data->op_def, input_shapes, input_tensors, /*input_tensors_as_shapes=*/{}, handle_shapes_and_types); auto status = c.Run(op_reg_data->shape_inference_fn); @@ -659,15 +772,17 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, // Determine if, during shape computation, the shape functions attempted to // query an input operand as shape where the input was not known/constant. bool requires_inputs = - llvm::any_of(llvm::seq(0, c.num_inputs()), [&](int input) { + any_of(llvm::seq(0, c.num_inputs()), [&](int input) { return c.requested_input_tensor_as_partial_shape(input) && !input_tensors[input]; }); if (requires_inputs) { + LLVM_DEBUG(llvm::dbgs() << "\trequired input\n"); std::vector input_tensors_as_shapes; for (int input : llvm::seq(0, c.num_inputs())) { if (c.requested_input_tensor_as_partial_shape(input) && !input_tensors[input]) { + LLVM_DEBUG(llvm::dbgs() << "Requesting " << input << " as shape\n"); auto op_result = op->getOperand(input).dyn_cast(); if (!op_result) continue; // Resize on first valid shape computed. 
@@ -723,7 +838,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, new_element_type.isa()) { auto handle_shapes_types = c.output_handle_shapes_and_types(output); if (handle_shapes_types) { - llvm::SmallVector subtypes; + SmallVector subtypes; OpBuilder b(op); for (const auto& shape_n_type : *handle_shapes_types) { Type element_type; @@ -743,7 +858,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, if (result.getType() == new_type) continue; // Inserts a cast back to the original type if any user is not in the TF // dialect. - AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect, + AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect_, result.getType()); // Finally we inferred the shape and replace the type for this result. result.setType(new_type); @@ -755,23 +870,13 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, return changed; } -// Updates input types and refine shapes inside body of functions that are -// attached to ControlFlow ops (If/While). These functions include Then/Else -// branches of IfOp and Cond/Body functions of WhileOp. These functions share -// following common properties: -// 1) They are never reused, ie. having a single use in module. -// 2) Their input types match those of their parent ops (excluding inputs like -// predicate). -// Returns a boolean indicating whether any change has been applied. -LogicalResult RefineShapeForControlFlowFunc(FuncOp func, - llvm::ArrayRef input_types, - int64_t graph_version, - int64_t max_iteration) { +LogicalResult ShapeInference::RefineShapeForControlFlowFunc( + FuncOp func, ArrayRef input_types, int64_t max_iteration) { ModuleOp module = func.getParentOfType(); auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion()); int num_uses = std::distance(func_uses->begin(), func_uses->end()); if (num_uses != 1) { - func.emitWarning(llvm::formatv( + func.emitWarning(formatv( "expected control flow function {0} to have exactly 1 use, found {1}.", func.getName(), num_uses)); return failure(); @@ -785,8 +890,7 @@ LogicalResult RefineShapeForControlFlowFunc(FuncOp func, arg_and_idx.value().setType(input_types[arg_and_idx.index()]); } - auto res = - InferShapeUntilFixPoint(&func.getBody(), graph_version, max_iteration); + auto res = InferShapeUntilFixPoint(&func.getBody(), max_iteration); if (failed(res)) return res; auto new_return_types = InferShapeForFunctionReturnType(func); @@ -798,85 +902,98 @@ LogicalResult RefineShapeForControlFlowFunc(FuncOp func, return success(); } -LogicalResult PropagateShapeToFunctions( +LogicalResult ShapeInference::PropagateShapeToFunctions( ModuleOp module, Operation::operand_type_range input_types, - llvm::ArrayRef func_names, int64_t graph_version, - int64_t max_iteration) { - bool success = true; + ArrayRef func_names, int64_t max_iteration) { + bool all_succeeded = true; auto types = llvm::to_vector<4>(input_types); for (auto func_name : func_names) { FuncOp func = module.lookupSymbol(func_name); - if (failed(RefineShapeForControlFlowFunc(func, types, graph_version, - max_iteration))) { - success = false; - } + all_succeeded = + succeeded(RefineShapeForControlFlowFunc(func, types, max_iteration)) && + all_succeeded; } - return mlir::success(success); + return success(all_succeeded); } -// If the callee has only one use, propagates any constant operand of call_op to -// the called function body's corresponding argument. 
-// -// TODO(b/154065712): Move this to a more general inter-procedural constant -// folding pass. -void PropagateConstantToCallee(CallOpInterface call_op, - SymbolRefAttr callee_sym, ModuleOp module) { +void ShapeInference::PropagateConstantToCallee(CallOpInterface call_op, + SymbolRefAttr callee_sym, + ModuleOp module) { auto func = module.lookupSymbol(callee_sym.getRootReference()); auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion()); int num_uses = std::distance(func_uses->begin(), func_uses->end()); + if (num_uses != 1) return; + OpBuilder builder(&func.front().front()); Operation* op = call_op.getOperation(); - if (num_uses == 1) { - // If this is the only caller, and an operand is a constant, propagate - // the constant inside the function. - for (auto arg : func.getArguments()) { - auto operand = op->getOperand(arg.getArgNumber()).getDefiningOp(); - if (llvm::isa_and_nonnull(operand)) { - arg.replaceAllUsesWith(builder.clone(*operand)->getResult(0)); + // If this is the only caller, and an operand is a constant, propagate + // the constant value inside the function. + for (auto arg : func.getArguments()) { + auto operand = op->getOperand(arg.getArgNumber()); + if (propagate_caller_callee_constants_) { + if (isa_and_nonnull(operand.getDefiningOp())) { + arg.replaceAllUsesWith( + builder.clone(*operand.getDefiningOp())->getResult(0)); } + continue; } + + auto known_constant = ComputeOutputComponent(ValuePort(operand)); + if (!known_constant) continue; + LLVM_DEBUG(call_op.print(llvm::dbgs() << "Propagate to calee: "); + known_constant.print(llvm::dbgs() << " constant "); + llvm::dbgs() << "\n"); + RecordValue(ValuePort(arg), known_constant); } } -// Propagates any constant return value of the callee function to the call op's -// corresponding result. -void PropagateConstantFromCallee(CallOpInterface call_op, - SymbolRefAttr callee_sym, ModuleOp module) { +void ShapeInference::PropagateConstantFromCallee(CallOpInterface call_op, + SymbolRefAttr callee_sym, + ModuleOp module) { auto func = module.lookupSymbol(callee_sym.getRootReference()); - // If the return value is a constant, replace the call result with a constant. + // If the return value is a constant, use the constant as the value of + // the call return. 
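+  // A rough sketch of the effect (the callee and its constant are
+  // illustrative):
+  //
+  //   func @callee() -> tensor<i32> {
+  //     %cst = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
+  //     return %cst : tensor<i32>
+  //   }
+  //
+  // For a call to @callee, the result is either rewritten to a cloned constant
+  // (legacy propagate_caller_callee_constants_ mode) or recorded as a known
+  // ValuePort so later queries see dense<4> without rewriting the IR.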
Operation* op = call_op.getOperation(); OpBuilder builder(op); builder.setInsertionPointAfter(op); for (auto retval : llvm::enumerate(func.front().getTerminator()->getOperands())) { - auto retval_op = retval.value().getDefiningOp(); - if (llvm::isa_and_nonnull(retval_op)) { - op->getResult(retval.index()) - .replaceAllUsesWith(builder.clone(*retval_op)->getResult(0)); + if (propagate_caller_callee_constants_) { + auto retval_op = retval.value().getDefiningOp(); + if (isa_and_nonnull(retval_op)) { + op->getResult(retval.index()) + .replaceAllUsesWith(builder.clone(*retval_op)->getResult(0)); + } + continue; + } + + ValuePort vp(retval.value()); + if (auto known_constant = ComputeOutputComponent(vp)) { + LLVM_DEBUG(known_constant.print(llvm::dbgs() << "Propagate constant "); + call_op.print(llvm::dbgs() << "from "); llvm::dbgs() << "\n"); + RecordValue(ValuePort(op->getResult(retval.index())), known_constant); } } } -LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op, - int64_t graph_version, - int64_t max_iteration) { +LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions( + Operation* op, int64_t max_iteration) { ModuleOp module = op->getParentOfType(); if (auto if_op = dyn_cast(op)) { return PropagateShapeToFunctions( - module, llvm::drop_begin(if_op.getOperandTypes(), 1), - {if_op.then_branch(), if_op.else_branch()}, graph_version, - max_iteration); + module, drop_begin(if_op.getOperandTypes(), 1), + {if_op.then_branch(), if_op.else_branch()}, max_iteration); } else if (auto while_op = dyn_cast(op)) { return PropagateShapeToFunctions(module, while_op.getOperandTypes(), {while_op.cond(), while_op.body()}, - graph_version, max_iteration); + max_iteration); } else if (auto call_op = dyn_cast(op)) { CallInterfaceCallable callable = call_op.getCallableForCallee(); if (SymbolRefAttr sym = callable.dyn_cast()) { PropagateConstantToCallee(call_op, sym, module); if (failed(PropagateShapeToFunctions( module, call_op.getArgOperands().getTypes(), - {sym.getRootReference()}, graph_version, max_iteration))) { + {sym.getRootReference()}, max_iteration))) { return failure(); } PropagateConstantFromCallee(call_op, sym, module); @@ -889,13 +1006,71 @@ LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op, return success(); } -LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version, - int64_t max_iteration) { - MLIRContext* ctx = region->getContext(); - Dialect* tf_dialect = ctx->getRegisteredDialect(); +LogicalResult ShapeInference::TryToFold(Operation* op) { + LLVM_DEBUG(op->print(llvm::dbgs() << "TryToFold "); llvm::dbgs() << "\n"); + // If any output result is known, then the op probably has been computed + // before. + if (op->getNumResults() > 0 && results_[ValuePort(op->getResult(0))]) + return success(); - // An operation folder that is used to attempt folding before inference. - OperationFolder folder(ctx); + SmallVector constant_operands(op->getNumOperands()); + SmallVector fold_results; + + // Check to see if any operands to the operation is constant and whether + // the operation knows how to constant fold itself. + bool some_unknown = false; + for (int i = 0, e = op->getNumOperands(); i != e; ++i) { + if (!(constant_operands[i] = + ComputeOutputComponent(ValuePort(op->getOperand(i))))) + some_unknown = true; + } + + // Attempt to constant fold the operation. 
+ auto* abstract_op = op->getAbstractOperation(); + LogicalResult folded = failure(); + if (abstract_op) { + folded = abstract_op->foldHook(op, constant_operands, fold_results); + } + // Attempt dialect fallback if op's fold hook failed. + if (failed(folded)) { + Dialect* dialect = op->getDialect(); + if (!dialect) return failure(); + // Only attempt TF dialect fallback if there are no unknown operands. + if (some_unknown && dialect == tf_dialect_) return failure(); + SmallVector constants; + if (failed(dialect->constantFoldHook(op, constant_operands, constants))) + return failure(); + fold_results.assign(constants.begin(), constants.end()); + } + + for (auto result : zip(op->getResults(), fold_results)) { + auto fold_result = std::get<1>(result); + Attribute attr = nullptr; + if ((attr = fold_result.dyn_cast())) { + RecordValue(ValuePort(std::get<0>(result)), attr); + } else { + auto value = fold_result.get(); + if ((attr = ComputeOutputComponent(ValuePort(value)))) + RecordValue(ValuePort(std::get<0>(result)), attr); + } + + if (ElementsAttr eattr = attr.dyn_cast_or_null()) { + if (std::get<0>(result).getType() == eattr.getType()) continue; + + // Inserts a cast back to the original type if any user is not in the + // TF dialect. + Type old_type = std::get<0>(result).getType(); + std::get<0>(result).setType(eattr.getType()); + AddCastBackForUnsupportedNonTFUses(op, std::get<0>(result), tf_dialect_, + old_type); + } + } + + return success(); +} + +LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region, + int64_t max_iteration) { bool changed = true; // TODO(aminim): we could have a more efficient traversal by guiding the @@ -908,29 +1083,27 @@ LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version, << "Shape inference, iteration " << iteration << "\n"); region->walk([&](Operation* op) { if (auto infer_ti = dyn_cast(op)) { - changed |= RefineWithInferTypeOpInterface(infer_ti, tf_dialect); - // TODO(jpienaar): Debug why we can't just return here. We end up with - // additional constant due to the propagation of constant into attached - // function if we return already. - } - - if (op->getDialect() != tf_dialect) { - changed |= InferShapeForNonTFDialectOperation(op, tf_dialect); + changed |= RefineWithInferTypeOpInterface(infer_ti, tf_dialect_); return; } - // Before attempting inference, just try to fold the operation. - if (succeeded(folder.tryToFold(op))) return; + if (op->getDialect() != tf_dialect_) { + changed |= InferShapeForNonTFDialectOperation(op, tf_dialect_); + return; + } + + // Before attempting inference, just try to compute the folded + // value/shape. + if (succeeded(TryToFold(op))) return; // Best-effort shape inference in attached functions. Do not return // failure even if it doesn't get to fixed point. 
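InferShapeUntilFixPoint above is a bounded fixed-point loop: walk every op, and stop once a full pass over the region changes nothing or max_iteration is reached. A minimal sketch of that control flow with a toy refinement callback (not the pass's API):

#include <functional>
#include <iostream>
#include <vector>

// Returns true if a fixed point was reached before max_iteration.
bool InferUntilFixPoint(std::vector<int>& ops,
                        const std::function<bool(int&)>& refine_one,
                        int max_iteration = 10) {
  bool changed = true;
  for (int iteration = 0; iteration < max_iteration && changed; ++iteration) {
    changed = false;
    for (int& op : ops) changed |= refine_one(op);  // walk every op once per round
  }
  return !changed;
}

int main() {
  std::vector<int> ranks = {-1, -1, 3};  // -1 models an unranked result type
  // Toy refinement: an unranked value becomes ranked the first time it is visited.
  auto refine = [](int& r) { if (r < 0) { r = 3; return true; } return false; };
  std::cout << (InferUntilFixPoint(ranks, refine) ? "converged" : "hit limit") << "\n";
}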
- if (failed(PropagateShapeIntoAttachedFunctions(op, graph_version, - max_iteration))) { + if (failed(PropagateShapeIntoAttachedFunctions(op, max_iteration))) { op->emitWarning() << "unable to refine shape of attached function " "arguments and bodies"; } - changed |= InferShapeForSingleOperation(op, tf_dialect, graph_version); + changed |= InferShapeForSingleOperation(op); }); } @@ -944,32 +1117,46 @@ LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version, LogicalResult InferShapeForFunction(FuncOp func, ArrayRef> arg_shapes, - int64_t graph_version) { - mlir::FunctionType func_type = func.getType(); + int64_t graph_version, + bool propagate_caller_callee_constants) { + ShapeInference context(graph_version, func.getContext(), + propagate_caller_callee_constants); + if (arg_shapes.empty()) { + if (failed(context.InferShapeUntilFixPoint(&func.getBody()))) + return failure(); + // TODO(b/156276510): Verify that it is always fine to refine a function's + // return type, as long as we do not change the argument shapes. + if (auto return_types = InferShapeForFunctionReturnType(func)) { + func.setType(FunctionType::get(func.getType().getInputs(), + return_types.getValue(), + func.getContext())); + } + + return success(); + } + FunctionType func_type = func.getType(); bool needs_refinement = false; - llvm::SmallVector new_arg_types; + SmallVector new_arg_types; new_arg_types.reserve(func_type.getNumInputs()); // Update argument types in-place using the provided arg_shapes. for (size_t i = 0; i < func_type.getNumInputs(); ++i) { ArrayRef shape = arg_shapes[i]; - mlir::Type element_type; - if (auto input_ty = - func_type.getInput(i).dyn_cast()) { - if (!input_ty || input_ty.getShape().size() != shape.size()) { + Type element_type; + if (auto input_ty = func_type.getInput(i).dyn_cast()) { + if (input_ty.getRank() != shape.size()) { return failure(); } element_type = input_ty.getElementType(); } else { - auto unranked_input_ty = - func_type.getInput(i).dyn_cast(); + auto unranked_input_ty = func_type.getInput(i).dyn_cast(); if (!unranked_input_ty) { return failure(); } element_type = unranked_input_ty.getElementType(); } - auto new_arg_type = mlir::RankedTensorType::get(shape, element_type); + auto new_arg_type = RankedTensorType::get(shape, element_type); if (new_arg_type != func_type.getInput(i)) { // If the new type is more detailed, trigger shape inference. func.getArgument(i).setType(new_arg_type); @@ -982,28 +1169,17 @@ LogicalResult InferShapeForFunction(FuncOp func, return success(); } - mlir::LogicalResult result = - mlir::TF::InferShapeUntilFixPoint(&func.getBody(), graph_version); + LogicalResult result = context.InferShapeUntilFixPoint(&func.getBody()); if (failed(result)) { return failure(); } auto return_types = InferShapeForFunctionReturnType(func); - func.setType(mlir::FunctionType::get(new_arg_types, - return_types.hasValue() - ? return_types.getValue() - : func.getType().getResults(), - func.getContext())); - - return success(); -} - -LogicalResult InferShapeForFunctionType(FuncOp func) { - if (auto return_types = InferShapeForFunctionReturnType(func)) { - func.setType(mlir::FunctionType::get(func.getType().getInputs(), - return_types.getValue(), - func.getContext())); - } + func.setType(FunctionType::get(new_arg_types, + return_types.hasValue() + ? 
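A small sketch of the per-argument refinement decision in InferShapeForFunction above, using a toy type in place of MLIR's RankedTensorType/UnrankedTensorType: a provided shape is rejected on rank mismatch, and inference is only triggered when the rebuilt type is genuinely more detailed than the current one.

#include <iostream>
#include <vector>

struct ToyType {
  bool ranked;               // false models an unranked tensor
  std::vector<long> shape;   // meaningful only when ranked
};

// Returns true if refinement should be triggered; sets *error on rank mismatch
// (the real pass returns failure() in that case).
bool RefineArgument(ToyType& input, const std::vector<long>& arg_shape, bool* error) {
  *error = false;
  if (input.ranked && input.shape.size() != arg_shape.size()) {
    *error = true;
    return false;
  }
  if (input.ranked && input.shape == arg_shape) return false;  // nothing new to learn
  input = {true, arg_shape};  // adopt the more detailed type
  return true;
}

int main() {
  ToyType arg{false, {}};  // unranked argument
  bool error = false;
  bool needs_refinement = RefineArgument(arg, {2, 224, 224, 3}, &error);
  std::cout << std::boolalpha << needs_refinement << " " << error << "\n";  // true false
}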
return_types.getValue() + : func.getType().getResults(), + func.getContext())); return success(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h index 0524ec678ed..7486fd77388 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h @@ -27,29 +27,14 @@ namespace mlir { namespace TF { -// Performs shape inference on the provided op and return true if the type of -// at least one result has been changed. -// A tf.Cast() is inserted for any uses that isn't in the TensorFlow dialect. -// `graph_version` indicates the current GraphDef compatibility versions -// (the versions field in graph.proto). -bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, - int64_t graph_version); - -// Infers shape on the provided region, including nested ones, iterate until fix -// point with a limit of max_iteration. Returns success if fix point is reached -// before max_iteration. -LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version, - int64_t max_iteration = 10); - // Given a list of refined shapes matching the function arguments of func, runs // shape inference over the function to propagate this updated information. -LogicalResult InferShapeForFunction(FuncOp func, - ArrayRef> arg_shapes, - int64_t graph_version); - -// Refines the return type of the given function by folding tf.Cast that -// precedes the return instruction. -LogicalResult InferShapeForFunctionType(FuncOp func); +// If arg_shapes are empty, then argument shapes will be left unchanged. +// TODO(b/154065712): Remove propagate_caller_callee_constants once using +// SCCP pass instead. +LogicalResult InferShapeForFunction( + FuncOp func, ArrayRef> arg_shapes, int64_t graph_version, + bool propagate_caller_callee_constants = true); } // namespace TF diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc index 48e4e77ce0f..1a846398412 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc @@ -47,8 +47,15 @@ namespace { // This transformation pass propagate shapes on the TensorFlow graph. // It is a ModulePass in order to be able to change function types. -struct ShapeInference +class ShapeInference : public PassWrapper> { + public: + ShapeInference() = default; + ShapeInference(const ShapeInference& that) { + propagate_caller_callee_constants_ = + that.propagate_caller_callee_constants_; + } + void runOnOperation() override { auto module = getOperation(); auto producer_or = tensorflow::GetTfGraphProducerVersion(module); @@ -58,12 +65,17 @@ struct ShapeInference } int64_t producer = producer_or.ValueOrDie(); for (auto func : module.getOps()) { - InferShapeUntilFixPoint(&func.getBody(), producer); - // TODO(yuanzx): Verify that it is always fine to refine a function's - // return type, as long as we do not change the argument shapes. 
- InferShapeForFunctionType(func); + if (failed(InferShapeForFunction(func, /*arg_shapes=*/{}, producer, + propagate_caller_callee_constants_))) + return signalPassFailure(); } } + + private: + Option propagate_caller_callee_constants_{ + *this, "propagate-caller-callee-constants", + llvm::cl::desc("Propagate constants between callers and callees"), + llvm::cl::init(true)}; }; PassRegistration pass( diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc index 6e27823191b..b2203c890e3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc @@ -254,22 +254,14 @@ LogicalResult HandleIfOp(TF::IfOp if_op, ModuleOp module, if (output_buffer_to_size.empty() && arg_no_changed) return success(); // Recreate the If op. auto new_if_operands = llvm::to_vector<8>(if_op.getOperands()); - auto new_output_shapes = llvm::to_vector<8>(if_op.output_shapes().getValue()); for (int64_t i = 1; i < if_op.getNumOperands(); ++i) { auto it = buffer_to_size->find(if_op.getOperand(i)); if (it == buffer_to_size->end()) continue; new_if_operands.push_back(it->getSecond().size); - if (!new_output_shapes.empty()) { - // Size is a scalar shape. - tensorflow::TensorShapeProto shape_proto; - new_output_shapes.push_back(builder.getStringAttr( - tensorflow::mangling_util::MangleShape(shape_proto))); - } } auto new_if = OpBuilder(if_op).create( if_op.getLoc(), then_branch.getType().getResults(), new_if_operands, if_op.getAttrs()); - new_if.setAttr("output_shapes", builder.getArrayAttr(new_output_shapes)); for (const auto& entry : output_buffer_to_size) { (*buffer_to_size)[new_if.getResult(std::get<0>(entry))] = { new_if.getResult(std::get<1>(entry)), std::get<2>(entry)}; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc index 141feeb6b24..688f21c1d52 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc @@ -14,11 +14,31 @@ limitations under the License. 
==============================================================================*/ #include +#include +#include +#include +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" namespace mlir { namespace TFTPU { @@ -30,30 +50,400 @@ namespace { constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation"; -struct TPUExtractHeadTailOutsideCompilation - : public PassWrapper { - void runOnFunction() override; -}; +bool HasOutsideCompilationAttribute(Operation* op) { + return op->getAttrOfType(kXlaOutsideCompilationAttr) != nullptr; +} -void TPUExtractHeadTailOutsideCompilation::runOnFunction() { - getFunction().walk([&](tf_device::LaunchOp launch) { - Block& launch_block = launch.GetBody(); - for (auto& op : llvm::make_early_inc_range(launch_block.getOperations())) { - // TODO(b/155115766): Handle outputs that should be inputs to TPU - // LaunchOp. - if (auto attr = - op.getAttrOfType(kXlaOutsideCompilationAttr)) { - op.moveBefore(launch); - } else { +// Finds op that created a given value. If the value is a BlockArgument, this +// returns the owner of the Block. +Operation* GetOpOfValue(Value value) { + if (auto block_arg = value.dyn_cast()) + return block_arg.getOwner()->getParentOp(); + + return value.getDefiningOp(); +} + +// Checks if `op` is nested in `block`. +bool OpInBlock(Operation* op, Block* block) { + Block* op_block = op->getBlock(); + while (op_block) { + if (op_block == block) return true; + if (auto* parent_op = op_block->getParentOp()) { + op_block = parent_op->getBlock(); + } else { + break; + } + } + return false; +} + +// Wraps block in a Launch. External uses of ops in the block will be return +// values of the Launch and remapped to the Launch results. If `before` is set +// to true, the Launch is created before `op`. Otherwise the Launch is created +// after `op`. +tf_device::LaunchOp CreateLaunchForBlock(OpBuilder* builder, Operation* op, + bool before, Block* launch_block, + llvm::StringRef host_device) { + // Find results and result types of ops in block that needs to returned. 
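OpInBlock above answers "is this op nested anywhere under that block" by walking up through parent blocks. A standalone sketch of the same ancestor walk with toy node types (not MLIR's Operation/Block classes):

#include <iostream>

struct ToyBlock;
struct ToyOp { ToyBlock* block = nullptr; };      // the block an op lives in
struct ToyBlock { ToyOp* parent_op = nullptr; };  // the op this block is nested in, if any

// True if `op` is (transitively) nested inside `block`.
bool OpInBlock(const ToyOp* op, const ToyBlock* block) {
  for (const ToyBlock* b = op->block; b != nullptr;
       b = b->parent_op ? b->parent_op->block : nullptr) {
    if (b == block) return true;
  }
  return false;
}

int main() {
  ToyBlock outer, inner;
  ToyOp region_op{&outer};       // an op in the outer block...
  inner.parent_op = &region_op;  // ...owns a nested block
  ToyOp nested_op{&inner};
  std::cout << OpInBlock(&nested_op, &outer) << "\n";  // 1
}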
+ llvm::SmallVector launch_results; + llvm::SmallVector launch_result_types; + for (Operation& head_outside_compiled_op : *launch_block) { + for (Value result : head_outside_compiled_op.getResults()) { + bool has_external_uses = false; + for (Operation* user : result.getUsers()) { + if (OpInBlock(user, launch_block)) continue; + has_external_uses = true; break; } + if (has_external_uses) { + launch_results.push_back(result); + launch_result_types.push_back(result.getType()); + } } - }); + } + + before ? builder->setInsertionPoint(op) : builder->setInsertionPointAfter(op); + auto launch = builder->create( + op->getLoc(), builder->getStringAttr(host_device), launch_result_types); + launch.body().push_back(launch_block); + + builder->setInsertionPointToEnd(&launch.GetBody()); + builder->create(op->getLoc(), launch_results); + + return launch; +} + +// Parses TPU compilation and execution devices from a TPU cluster and returns +// the host device for the head and tail computations. If the TPU computation is +// replicated, kTPUReplicatedHost is returned instead. +LogicalResult GetHostDeviceForHeadTailComputation( + mlir::TF::RuntimeDevices devices, tf_device::ClusterOp cluster, + std::string* host_device) { + auto replicate = cluster.getParentOfType(); + if (replicate) { + *host_device = tensorflow::kTPUReplicatedHost; + return success(); + } + + auto num_cores_per_replica_attr = + cluster.getAttrOfType(tensorflow::kNumCoresPerReplicaAttr); + if (!num_cores_per_replica_attr) + return cluster.emitOpError( + "cluster op missing `num_cores_per_replica` attribute"); + + if (num_cores_per_replica_attr.getInt() != 1) + return cluster.emitOpError( + "outside compilation is not supported with model parallelism."); + + auto topology_attr = + cluster.getAttrOfType(tensorflow::kTopologyAttr); + if (!topology_attr) + return cluster.emitOpError("cluster op missing `topology` attribute"); + + auto device_assignment_attr = + cluster.getAttrOfType(tensorflow::kDeviceAssignmentAttr); + if (!device_assignment_attr) + return cluster.emitOpError(llvm::formatv("requires attribute '{0}'", + tensorflow::kDeviceAssignmentAttr) + .str()); + + auto status_or_device_coodinates = + tensorflow::GetDeviceCoordinates(device_assignment_attr); + + if (!status_or_device_coodinates.ok()) + return cluster.emitError() + << "error in fetching tpu device coordinates: " + << status_or_device_coodinates.status().error_message(); + + // Determine compilation and execution devices. + auto status_or_tpu_device_assignment = + tensorflow::GetTPUCompilationAndExecutionDevices( + devices.device_names(), /*num_replicas=*/1, + /*num_cores_per_replica=*/1, topology_attr.getValue(), + status_or_device_coodinates.ConsumeValueOrDie()); + if (!status_or_tpu_device_assignment.ok()) + return cluster.emitError() + << "error in fetching TPU compilation/execution devices: " + << status_or_tpu_device_assignment.status().error_message(); + auto& tpu_device_assignment = status_or_tpu_device_assignment.ValueOrDie(); + + *host_device = tpu_device_assignment.tpu_devices[0][0].host; + return success(); +} + +// Returns a set of ops that are outside compiled and can be extracted to before +// the TPU computation. These ops are either connected to the inputs of the TPU +// computation or other ops that can be extracted, and have no operands from +// other ops in the TPU computation that cannot be extracted. 
+llvm::SmallVector FindOutsideCompiledOpsAtHead( + tf_device::ClusterOp cluster) { + Region* cluster_region = &cluster.body(); + llvm::SmallSetVector head_outside_compiled_ops; + + auto cluster_ops = cluster.GetBody().without_terminator(); + for (Operation& cluster_op : cluster_ops) { + if (!HasOutsideCompilationAttribute(&cluster_op)) continue; + // An outside compiled op can be extracted if its operands are not from + // other ops in the cluster that cannot be extracted. + auto walk_result = cluster_op.walk([&](Operation* op) { + for (Value operand : op->getOperands()) { + Operation* operand_op = GetOpOfValue(operand); + if (head_outside_compiled_ops.count(operand_op)) continue; + + if (operand_op->getParentRegion() == cluster_region) + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + + if (!walk_result.wasInterrupted()) + head_outside_compiled_ops.insert(&cluster_op); + } + + return head_outside_compiled_ops.takeVector(); +} + +// Moves head outside compiled ops into its own `tf_device.LaunchOp` +// computation before the cluster. +void CreateHeadComputation(OpBuilder* builder, tf_device::ClusterOp cluster, + llvm::ArrayRef head_outside_compiled_ops, + llvm::StringRef host_device) { + Block* launch_block = new Block; + for (Operation* head_outside_compiled_op : head_outside_compiled_ops) + head_outside_compiled_op->moveBefore(launch_block, launch_block->end()); + + tf_device::LaunchOp launch = CreateLaunchForBlock( + builder, cluster, /*before=*/true, launch_block, host_device); + + for (auto result : llvm::zip(launch.GetBody().getTerminator()->getOperands(), + launch.getResults())) + replaceAllUsesInRegionWith(std::get<0>(result), std::get<1>(result), + cluster.body()); +} + +// Extracts and move outside compiled ops that have no dependencies in the +// cluster to before the cluster. +mlir::LogicalResult LiftHeadOutsideCompiledOps( + OpBuilder* builder, const mlir::TF::RuntimeDevices& devices, + tf_device::ClusterOp cluster, std::string* host_device, + bool* cluster_updated) { + llvm::SmallVector head_outside_compiled_ops = + FindOutsideCompiledOpsAtHead(cluster); + if (head_outside_compiled_ops.empty()) return success(); + if (failed( + GetHostDeviceForHeadTailComputation(devices, cluster, host_device))) + return failure(); + + CreateHeadComputation(builder, cluster, head_outside_compiled_ops, + *host_device); + + *cluster_updated = true; + return success(); +} + +// Fills `tail_outside_compiled_ops` with ops that are outside compiled and +// can be extracted to after the TPU computation, and `cluster_results` with new +// results of the cluster. These ops are either connected to the output of the +// TPU computation or other ops that can be extracted, and have no results used +// by other ops in the TPU computation that cannot be extracted. 
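A compact model of the head-extraction criterion implemented above, on a toy op list rather than an MLIR region (names hypothetical): walking the cluster forward, an outside-compiled op is liftable only if every in-cluster producer of its operands has already been selected; values defined outside the cluster never block extraction.

#include <iostream>
#include <set>
#include <string>
#include <vector>

struct ToyOp {
  std::string name;
  bool outside_compiled;
  std::vector<std::string> operands;  // only in-cluster producers are listed here
};

int main() {
  std::vector<ToyOp> cluster = {
      {"a", true, {}},      // no in-cluster operands -> liftable
      {"b", false, {"a"}},  // device op, stays in the TPU cluster
      {"c", true, {"b"}},   // depends on an unextractable cluster op -> stays
  };
  std::set<std::string> head;
  for (const ToyOp& op : cluster) {
    if (!op.outside_compiled) continue;
    bool ok = true;
    for (const std::string& operand : op.operands)
      if (!head.count(operand)) ok = false;  // produced by an unextracted cluster op
    if (ok) head.insert(op.name);
  }
  for (const std::string& n : head) std::cout << n << "\n";  // prints: a
}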
+void FindOutsideCompiledOpsAtTailAndClusterResults( + tf_device::ClusterOp cluster, + llvm::SmallVectorImpl* tail_outside_compiled_ops, + llvm::SmallVectorImpl* cluster_results) { + Region* cluster_region = &cluster.body(); + llvm::SmallSetVector tail_outside_compiled_ops_set; + Operation* terminator = cluster.GetBody().getTerminator(); + llvm::SmallSetVector cluster_results_set; + cluster_results_set.insert(terminator->getOperands().begin(), + terminator->getOperands().end()); + + auto cluster_ops = llvm::reverse(cluster.GetBody().without_terminator()); + for (Operation& cluster_op : cluster_ops) { + if (!HasOutsideCompilationAttribute(&cluster_op)) continue; + + llvm::SmallVector results_to_forward; + bool can_be_extracted = + llvm::all_of(cluster_op.getUsers(), [&](Operation* op) { + return op == terminator || tail_outside_compiled_ops_set.count(op); + }); + if (!can_be_extracted) continue; + + // Collect operands of cluster op that are generated within the cluster. + // These values should be returned by the cluster. + cluster_op.walk([&](Operation* op) { + for (Value operand : op->getOperands()) { + Operation* operand_op = GetOpOfValue(operand); + if (operand_op->getParentRegion() == cluster_region) + cluster_results_set.insert(operand); + } + }); + + // Remove results of op to be extracted as there are no uses in the cluster. + for (Value result : cluster_op.getResults()) + cluster_results_set.remove(result); + tail_outside_compiled_ops_set.insert(&cluster_op); + } + + *tail_outside_compiled_ops = tail_outside_compiled_ops_set.takeVector(); + *cluster_results = cluster_results_set.takeVector(); +} + +// Moves tail outside compiled ops into its own `tf_device.LaunchOp` +// computation after the cluster. +void CreateTailComputation(OpBuilder* builder, tf_device::ClusterOp cluster, + llvm::ArrayRef tail_outside_compiled_ops, + llvm::StringRef host_device) { + Block* launch_block = new Block; + for (Operation* tail_outside_compiled_op : tail_outside_compiled_ops) + tail_outside_compiled_op->moveBefore(launch_block, launch_block->begin()); + + tf_device::LaunchOp launch = CreateLaunchForBlock( + builder, cluster, /*before=*/false, launch_block, host_device); + + auto operand_not_in_launch = [&](OpOperand& operand) { + return !launch.getOperation()->isProperAncestor(operand.getOwner()); + }; + for (auto result : llvm::zip(launch.GetBody().getTerminator()->getOperands(), + launch.getResults())) + std::get<0>(result).replaceUsesWithIf(std::get<1>(result), + operand_not_in_launch); +} + +// Updates cluster with updated cluster results after extracting tail outside +// compiled ops. 
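The tail side mirrors the head side. A toy sketch of the reverse walk above: an outside-compiled op can move after the cluster only if each of its users is the terminator or an op already selected for extraction (stand-in types, not the pass's API).

#include <iostream>
#include <set>
#include <string>
#include <vector>

struct ToyOp {
  std::string name;
  bool outside_compiled;
  std::vector<std::string> users;  // in-cluster users; "return" models the terminator
};

int main() {
  std::vector<ToyOp> cluster = {
      {"x", false, {"y"}},      // device op feeding y
      {"y", true, {"return"}},  // only used by the terminator -> liftable
      {"z", true, {"x"}},       // feeds a device op -> must stay
  };
  std::set<std::string> tail;
  for (auto it = cluster.rbegin(); it != cluster.rend(); ++it) {
    if (!it->outside_compiled) continue;
    bool ok = true;
    for (const std::string& user : it->users)
      if (user != "return" && !tail.count(user)) ok = false;
    if (ok) tail.insert(it->name);
  }
  for (const std::string& n : tail) std::cout << n << "\n";  // prints: y
}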
+tf_device::ClusterOp UpdateClusterResults( + OpBuilder* builder, tf_device::ClusterOp cluster, + llvm::ArrayRef new_cluster_results) { + Operation* old_terminator = cluster.GetBody().getTerminator(); + builder->setInsertionPoint(old_terminator); + builder->create(old_terminator->getLoc(), + new_cluster_results); + old_terminator->erase(); + + builder->setInsertionPoint(cluster); + llvm::SmallVector new_cluster_result_types; + new_cluster_result_types.reserve(new_cluster_results.size()); + for (const auto& new_cluster_result : new_cluster_results) + new_cluster_result_types.push_back(new_cluster_result.getType()); + + auto new_cluster = builder->create( + cluster.getLoc(), new_cluster_result_types, + /*operands=*/llvm::ArrayRef{}, cluster.getAttrs()); + new_cluster.body().takeBody(cluster.body()); + + auto operand_not_in_cluster = [&](OpOperand& operand) { + return !new_cluster.getOperation()->isProperAncestor(operand.getOwner()); + }; + for (auto result : + llvm::zip(new_cluster.GetBody().getTerminator()->getOperands(), + new_cluster.getResults())) + std::get<0>(result).replaceUsesWithIf(std::get<1>(result), + operand_not_in_cluster); + + cluster.erase(); + return new_cluster; +} + +// Extracts and move outside compiled ops that do not create dependencies in the +// cluster to after the cluster. +mlir::LogicalResult LiftTailOutsideCompiledOps( + OpBuilder* builder, const mlir::TF::RuntimeDevices& devices, + std::string host_device, tf_device::ClusterOp* cluster, + bool* cluster_updated) { + llvm::SmallVector tail_outside_compiled_ops; + llvm::SmallVector cluster_results; + FindOutsideCompiledOpsAtTailAndClusterResults( + *cluster, &tail_outside_compiled_ops, &cluster_results); + if (tail_outside_compiled_ops.empty()) return success(); + + if (host_device.empty()) + if (failed(GetHostDeviceForHeadTailComputation(devices, *cluster, + &host_device))) + return failure(); + + // Forward all results of cluster first. These results will be remapped once + // a new cluster is formed. + cluster->replaceAllUsesWith( + cluster->GetBody().getTerminator()->getOperands()); + + CreateTailComputation(builder, *cluster, tail_outside_compiled_ops, + host_device); + + *cluster = UpdateClusterResults(builder, *cluster, cluster_results); + + *cluster_updated = true; + return success(); +} + +// Removes aliased outputs in cluster from ops outside of cluster. 
+void RemoveClusterAliasedOutputs(OpBuilder* builder, + tf_device::ClusterOp cluster) { + llvm::SmallVector used_old_cluster_results; + llvm::SmallVector new_cluster_results; + llvm::SmallVector new_cluster_result_types; + Operation* cluster_terminator = cluster.GetBody().getTerminator(); + for (auto result : + llvm::zip(cluster_terminator->getOperands(), cluster.getResults())) { + Value cluster_terminator_operand = std::get<0>(result); + if (cluster.getOperation()->isProperAncestor( + cluster_terminator_operand.getDefiningOp())) { + new_cluster_results.push_back(cluster_terminator_operand); + new_cluster_result_types.push_back(cluster_terminator_operand.getType()); + used_old_cluster_results.push_back(std::get<1>(result)); + } else { + std::get<1>(result).replaceAllUsesWith(cluster_terminator_operand); + } + } + + if (new_cluster_results.size() == cluster.getNumResults()) return; + + builder->setInsertionPoint(cluster); + auto new_cluster = builder->create( + cluster.getLoc(), new_cluster_result_types, + /*operands=*/llvm::ArrayRef{}, cluster.getAttrs()); + new_cluster.body().takeBody(cluster.body()); + new_cluster.GetBody().getTerminator()->setOperands(new_cluster_results); + + for (auto result : + llvm::zip(used_old_cluster_results, new_cluster.getResults())) + std::get<0>(result).replaceAllUsesWith(std::get<1>(result)); + + cluster.erase(); +} + +struct TPUExtractHeadTailOutsideCompilation + : public PassWrapper> { + void runOnOperation() override; +}; + +void TPUExtractHeadTailOutsideCompilation::runOnOperation() { + // Get runtime devices information from the closest parent module. + auto module = getOperation(); + mlir::TF::RuntimeDevices devices; + if (failed(tensorflow::GetDevicesFromOp(module, &devices))) + return signalPassFailure(); + + OpBuilder builder(&getContext()); + llvm::SmallVector clusters; + module.walk( + [&](tf_device::ClusterOp cluster) { clusters.push_back(cluster); }); + + for (tf_device::ClusterOp cluster : clusters) { + std::string host_device; + bool cluster_updated = false; + if (failed(LiftHeadOutsideCompiledOps(&builder, devices, cluster, + &host_device, &cluster_updated)) || + failed(LiftTailOutsideCompiledOps(&builder, devices, host_device, + &cluster, &cluster_updated))) + return signalPassFailure(); + if (cluster_updated) RemoveClusterAliasedOutputs(&builder, cluster); + } } } // anonymous namespace -std::unique_ptr> +std::unique_ptr> CreateTPUExtractHeadTailOutsideCompilationPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc index 4e20cd9d64b..93e5cc22c30 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc @@ -19,22 +19,27 @@ limitations under the License. 
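A toy illustration of the RemoveClusterAliasedOutputs logic above (hypothetical data model): cluster results whose returned value is not produced inside the cluster are bypassed entirely, and only the remaining results survive on the rebuilt cluster.

#include <iostream>
#include <string>
#include <vector>

struct ToyResult {
  std::string returned_value;  // the terminator operand backing this cluster result
  bool defined_in_cluster;
};

int main() {
  std::vector<ToyResult> results = {
      {"%inside", true},    // produced by the cluster body: keep as a result
      {"%outside", false},  // pass-through of an outer value: forward users directly
  };
  std::vector<std::string> kept;
  for (const ToyResult& r : results) {
    if (r.defined_in_cluster)
      kept.push_back(r.returned_value);
    else
      std::cout << "forward uses of this cluster result to " << r.returned_value << "\n";
  }
  std::cout << "rebuilt cluster has " << kept.size() << " result(s)\n";  // 1
}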
#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/FormatVariadic.h" #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" -#include "tensorflow/core/platform/logging.h" namespace mlir { namespace TFTPU { namespace { -constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation"; +constexpr char kAncestorsAttr[] = "ancestors"; constexpr char kDeviceAttr[] = "device"; +constexpr char kKeyAttr[] = "key"; +constexpr char kShapesAttr[] = "shapes"; +constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation"; // Mapping for `_xla_outside_compilation` attribute to ops of a cluster. -using ClusterMap = +using OutsideClusterMap = llvm::SmallDenseMap, 8>; // This pass extracts a CPU computation cluster with `_xla_outside_compilation` @@ -51,7 +56,8 @@ struct TPUExtractOutsideCompilation // Collects and clusters ops in `block` with the same `_xla_outside_compilation` // attribute into `clusters` This returns an error if a // `_xla_outside_compilation` attribute of an op is empty. -LogicalResult CollectAndGroupClusterOps(Block* block, ClusterMap* clusters) { +LogicalResult CollectAndGroupOutsideClusterOps(Block* block, + OutsideClusterMap* clusters) { for (Operation& op : *block) { if (auto attr = op.getAttrOfType(kXlaOutsideCompilationAttr)) { if (attr.getValue().empty()) @@ -67,7 +73,7 @@ LogicalResult CollectAndGroupClusterOps(Block* block, ClusterMap* clusters) { } // Moves `cluster_ops` to associated `launch_op` body. -void MoveClusterOpsToLaunchOp( +void MoveOutsideClusterOpsToLaunchOp( tf_device::LaunchOp launch_op, const llvm::SmallVector& cluster_ops) { MLIRContext* context = launch_op.getContext(); @@ -84,8 +90,8 @@ void MoveClusterOpsToLaunchOp( } // Creates a `tf_device::LaunchOp` to wrap cluster ops. -tf_device::LaunchOp CreateLaunchOpForCluster(OpBuilder* builder, - Operation* last_cluster_op) { +tf_device::LaunchOp CreateLaunchOpForOutsideCluster( + OpBuilder* builder, Operation* last_cluster_op) { // TODO(b/154363171): Set the CPU device. // An empty string placeholder is used for the device as that will be later // populated with the device of the associated TPUReplicateMetadata op. @@ -115,16 +121,159 @@ void PropagateParallelExecuteReturnToReplicate( parallel_execute_op.execute_outputs()); } +// Extracts all externally provided operands of `cluster_ops`. +llvm::SmallSetVector GetExternalOperands( + const llvm::SmallVector& cluster_ops) { + llvm::SmallSetVector external_values; + + for (Operation* op : cluster_ops) { + for (Value v : op->getOperands()) { + Operation* defining_op = v.getDefiningOp(); + if (!defining_op) continue; + bool is_external = llvm::none_of(cluster_ops, [&](Operation* cluster_op) { + return defining_op == cluster_op; + }); + + if (is_external) external_values.insert(v); + } + } + + return external_values; +} + +// Extracts all externally used outputs of `cluster_ops`. 
+llvm::SmallVector GetExternalOutputs( + const llvm::SmallVector& cluster_ops) { + llvm::SmallSetVector external_outputs; + + for (Operation* op : cluster_ops) { + for (Operation* user : op->getUsers()) { + bool is_external = llvm::none_of(cluster_ops, [&](Operation* cluster_op) { + return user == cluster_op; + }); + if (!is_external) continue; + for (Value v : user->getOperands()) { + if (v.getDefiningOp() == op) external_outputs.insert(v); + } + } + } + + return external_outputs.takeVector(); +} + +// Sets the insertion point on `builder` for HostCompute op. Sets insertion +// point to the first op in `cluster_ops` that has one of `external_inputs` +// as an operand. If there are no external_inputs, set insertion point to first +// cluster_op. +void SetHostComputeInsertion( + OpBuilder* builder, const llvm::SmallVector& cluster_ops, + const llvm::SmallSetVector& external_inputs) { + if (external_inputs.empty()) builder->setInsertionPoint(cluster_ops.front()); + for (const auto& cluster_op : cluster_ops) { + for (Value v : cluster_op->getOperands()) { + if (external_inputs.count(v)) { + builder->setInsertionPoint(cluster_op); + return; + } + } + } +} + +// Creates the HostCompute with `inputs` and `outputs` +// using `communication_key`. +TF::_HostComputeMlirOp CreateHostCompute( + OpBuilder* builder, tf_device::ClusterOp tpu_cluster, + const llvm::SmallVector& cluster_ops, + const llvm::SmallSetVector& inputs, llvm::ArrayRef outputs, + const std::string& communication_key) { + llvm::SmallVector device_output_types; + for (const auto& output : outputs) + device_output_types.push_back(output.getType()); + SetHostComputeInsertion(builder, cluster_ops, inputs); + auto host_compute = builder->create( + tpu_cluster.getLoc(), device_output_types, inputs.getArrayRef(), + llvm::ArrayRef{}); + host_compute.setAttr(kAncestorsAttr, builder->getArrayAttr({})); + host_compute.setAttr(kShapesAttr, builder->getArrayAttr({})); + host_compute.setAttr(kKeyAttr, builder->getStringAttr(communication_key)); + return host_compute; +} + +void MoveOutsideCompiledOps( + tf_device::ClusterOp tpu_cluster, llvm::StringRef outside_cluster_name, + tf_device::LaunchOp host_launch_op, + const llvm::SmallVector& cluster_ops, + const llvm::SmallSetVector& external_inputs, + const llvm::SmallVector& external_outputs) { + if (external_inputs.empty() && external_outputs.empty()) { + MoveOutsideClusterOpsToLaunchOp(host_launch_op, cluster_ops); + return; + } + + OpBuilder builder(host_launch_op.GetBody().getTerminator()); + auto result_type = + RankedTensorType::get({}, builder.getType()); + + std::string txt_metadata; + std::string txt_module; + // TODO(b/157054714): Use a better abstraction instead of _TPUCompileMlirOp + // and _XlaRecvAtHostOp and _XlaSendFromHostOp. + + // A placeholder _TpuCompileMlirOp is created because it is required input to + // XlaRecvAtHostOp and XlaSendFromHostOp but the _TpuCompileMlirOp has not yet + // been created for the TPU cluster that contains the outside compiled ops. + // This placeholder should be replaced by the TPU cluster _TPUCompileMlirOp in + // a subsequent pass. 
+ auto compile_op = builder.create( + tpu_cluster.getLoc(), /*compilation_status=*/result_type, /*program=*/ + llvm::ArrayRef{result_type}, llvm::ArrayRef{}, txt_module, + txt_metadata); + + llvm::SmallVector host_output_types; + for (const auto& external_input : external_inputs) + host_output_types.push_back(external_input.getType()); + + std::string communication_key = + llvm::formatv("host_compute_channel_{0}", outside_cluster_name).str(); + // XlaRecvAtHostOp takes both the program key(dynamic_key) from the + // _TpuCompileMlirOp and the communication_key. + auto recv_at_host = builder.create( + tpu_cluster.getLoc(), host_output_types, + /*dynamic_key=*/compile_op.getResult(1), + builder.getStringAttr(communication_key), + builder.getIntegerAttr(builder.getIntegerType(64), 0)); + + auto host_compute = + CreateHostCompute(&builder, tpu_cluster, cluster_ops, external_inputs, + external_outputs, communication_key); + MoveOutsideClusterOpsToLaunchOp(host_launch_op, cluster_ops); + + builder.setInsertionPoint(host_launch_op.GetBody().getTerminator()); + builder.create( + tpu_cluster.getLoc(), external_outputs, + /*dynamic_key=*/compile_op.getResult(1), + builder.getStringAttr(communication_key), + /*device_ordinal=*/builder.getIntegerAttr(builder.getIntegerType(64), 0)); + + for (auto result : llvm::zip(external_inputs, recv_at_host.getResults())) + mlir::replaceAllUsesInRegionWith(std::get<0>(result), std::get<1>(result), + host_launch_op.body()); + + for (auto result : llvm::zip(external_outputs, host_compute.getResults())) + mlir::replaceAllUsesInRegionWith(std::get<0>(result), std::get<1>(result), + tpu_cluster.body()); +} + // Creates a `parallel_execute` op in place of launch with 'clusters` and // 'launch` as regions. -void CreateParallelExecuteFromClusters(tf_device::LaunchOp launch, - const ClusterMap& clusters) { - OpBuilder builder(launch); +void CreateParallelExecuteFromOutsideClusters( + tf_device::ClusterOp tpu_cluster, const OutsideClusterMap& clusters) { + OpBuilder builder(tpu_cluster); // Create parallel_execute regions. The original TPU cluster computation // is the extra region. - int num_regions = 1 + clusters.size(); + const int num_regions = 1 + clusters.size(); auto parallel_execute_op = builder.create( - launch.getLoc(), num_regions, launch.results().getTypes()); + tpu_cluster.getLoc(), num_regions, tpu_cluster.results().getTypes()); // Move outside compilation clusters to parallel_execute regions. for (const auto& cluster : llvm::enumerate(clusters)) { @@ -133,41 +282,48 @@ void CreateParallelExecuteFromClusters(tf_device::LaunchOp launch, Block& outside_block = parallel_execute_op.GetRegionBlockWithIndex(cluster.index()); builder.setInsertionPointToEnd(&outside_block); - tf_device::LaunchOp launch_op = - CreateLaunchOpForCluster(&builder, cluster_ops.back()); - MoveClusterOpsToLaunchOp(launch_op, cluster_ops); + tf_device::LaunchOp host_launch_op = + CreateLaunchOpForOutsideCluster(&builder, cluster_ops.back()); + + // Determine if there are any inputs that are provided out of cluster. 
+ auto external_inputs = GetExternalOperands(cluster_ops); + auto external_outputs = GetExternalOutputs(cluster_ops); + + MoveOutsideCompiledOps(tpu_cluster, cluster.value().getFirst(), + host_launch_op, cluster_ops, external_inputs, + external_outputs); + builder.setInsertionPointToEnd(&outside_block); - // TODO(b/154363171): Handle returns from OutsideCompiled parallel_execute - // regions either through communication with TPU parallel_execute regions - // or modifying parallel_execute returns. - builder.create(launch.getLoc(), ArrayRef{}); + builder.create(tpu_cluster.getLoc(), + ArrayRef{}); } // Move the launch body to last parallel_execute block. - Block& inside_block = + Block& parallel_execute_tpu_block = parallel_execute_op.GetRegionBlockWithIndex(num_regions - 1); - builder.setInsertionPointToEnd(&inside_block); - builder.create(launch.getLoc(), launch.getResults()); - launch.getOperation()->moveBefore(inside_block.getTerminator()); + builder.setInsertionPointToEnd(¶llel_execute_tpu_block); + builder.create(tpu_cluster.getLoc(), + tpu_cluster.getResults()); + tpu_cluster.getOperation()->moveBefore( + parallel_execute_tpu_block.getTerminator()); PropagateParallelExecuteReturnToReplicate(parallel_execute_op); - // TODO(b/154363171): Handle returns from OutsideCompiled parallel_execute - // regions either through communication with TPU parallel_execute regions - // or modifying parallel_execute returns. } void TPUExtractOutsideCompilation::runOnFunction() { - auto extract_result = getFunction().walk([&](tf_device::LaunchOp launch) { - ClusterMap clusters; - if (failed(CollectAndGroupClusterOps(&launch.GetBody(), &clusters))) - return WalkResult::interrupt(); + auto extract_result = + getFunction().walk([&](tf_device::ClusterOp tpu_cluster) { + OutsideClusterMap clusters; + if (failed(CollectAndGroupOutsideClusterOps(&tpu_cluster.GetBody(), + &clusters))) + return WalkResult::interrupt(); - if (clusters.empty()) return WalkResult::advance(); + if (clusters.empty()) return WalkResult::advance(); - CreateParallelExecuteFromClusters(launch, clusters); + CreateParallelExecuteFromOutsideClusters(tpu_cluster, clusters); - return WalkResult::advance(); - }); + return WalkResult::advance(); + }); if (extract_result.wasInterrupted()) return signalPassFailure(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc index 98ff0de7645..696882cd105 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc @@ -64,19 +64,14 @@ static llvm::cl::opt tpu_compile_metadata_debug( "'tf._TPUCompileMlir' op as a proto debug string")); constexpr char kNumReplicasAttr[] = "num_replicas"; -constexpr char kNumCoresPerReplicaAttr[] = "num_cores_per_replica"; constexpr char kStepMarkerLocationAttr[] = "step_marker_location"; constexpr char kPaddingMapAttr[] = "padding_map"; -constexpr char kTopologyAttr[] = "topology"; -constexpr char kDeviceAssignmentAttr[] = "device_assignment"; constexpr char kDeviceAttr[] = "device"; constexpr char kDevicesAttr[] = "devices"; constexpr char kVersionsAttr[] = "tf.versions"; constexpr char kBadStringArrayElementMsg[] = "bad '{0}' attribute at index {1}, not a string"; -constexpr char kBadIntArrayElementMsg[] = - "bad '{0}' attribute at index {1}, not an int"; constexpr char kBadArrayElementMsg[] = "bad '{0}' attribute at index {1} with value '{2}': failed to parse to {3}"; constexpr char 
kBadArrayAttrLengthMsg[] = @@ -92,7 +87,7 @@ constexpr char kBadArrayAttrLengthMsg[] = // // Would become following ops (unimportant attributes, types are omitted): // %1 = "tf.Shape"(%0) -// %2:2 = "tf.MLIRCompileToTPU"(%1) {module = ""} +// %2:2 = "tf._TPUCompileMlir"(%1) {module = ""} // "tf.TPUCompileSucceededAssert"(%2#0) // %3 = "tf.TPUExecute"(%0, %2#1) // %4 = "tf.SomeOp"(%3) @@ -163,32 +158,6 @@ LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func, return success(); } -// Extracts device coordinates from a device assignment attribute on an op. -LogicalResult GetDeviceCoordinates( - tf_device::ClusterFuncOp op, - llvm::SmallVectorImpl* device_assignment) { - auto device_assignment_attr = - op.getAttrOfType(kDeviceAssignmentAttr); - if (!device_assignment_attr) - return op.emitOpError(CreateMissingAttributeMsg(kDeviceAssignmentAttr)); - - device_assignment->reserve(device_assignment_attr.size()); - - for (auto device_coordinate_and_idx : - llvm::enumerate(device_assignment_attr)) { - auto device_coordinate = - device_coordinate_and_idx.value().dyn_cast(); - if (!device_coordinate) - return op.emitOpError(llvm::formatv(kBadIntArrayElementMsg, - kDeviceAssignmentAttr, - device_coordinate_and_idx.index())); - - device_assignment->push_back(device_coordinate.getInt()); - } - - return success(); -} - // Populates a TPUCompileMetadataProto with StepMarkerLocation from a // `tf_device::ClusterFuncOp`. LogicalResult SetMetadataProtoStepMarkerLocation( @@ -448,25 +417,38 @@ Operation* BuildCompileOp( // core, and all replica devices per core are grouped together. void AssignDevicesToReplicate( tf_device::ReplicateOp replicate, - llvm::ArrayRef> execution_devices, + llvm::ArrayRef> + tpu_devices, OpBuilder* builder) { if (!replicate) return; - const int num_replicas = execution_devices.size(); - const int num_cores_per_replica = execution_devices.front().size(); + const int num_replicas = tpu_devices.size(); + const int num_cores_per_replica = tpu_devices.front().size(); llvm::SmallVector device_attrs; for (int core = 0; core < num_cores_per_replica; ++core) { llvm::SmallVector devices_by_core; devices_by_core.reserve(num_replicas); for (int replica = 0; replica < num_replicas; ++replica) - devices_by_core.push_back(execution_devices[replica][core]); + devices_by_core.push_back(tpu_devices[replica][core].device); device_attrs.push_back( builder->getNamedAttr(tensorflow::GetDeviceAliasForLogicalCore(core), builder->getStrArrayAttr(devices_by_core))); } + // For data parallelism, also add replicated host devices, as these are + // necessary for outside compilation. + if (num_cores_per_replica == 1) { + llvm::SmallVector hosts; + hosts.reserve(num_replicas); + for (int replica = 0; replica < num_replicas; ++replica) + hosts.push_back(tpu_devices[replica][0].host); + + device_attrs.push_back(builder->getNamedAttr( + tensorflow::kTPUReplicatedHost, builder->getStrArrayAttr(hosts))); + } + replicate.setAttr(kDevicesAttr, builder->getDictionaryAttr(device_attrs)); } @@ -492,11 +474,12 @@ LogicalResult BuildExecuteOp( // Creates a tf_device.parallel_execute op that wraps TPUExecute op to // represent execution of TPU program in multiple logical cores. 
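A worked example of the device grouping in AssignDevicesToReplicate above; the device strings and alias names below are illustrative only. Per logical core, that core's device is collected from every replica, and with a single core per replica the per-replica host devices are also collected for outside compilation.

#include <iostream>
#include <string>
#include <vector>

struct ToyTPUDevice { std::string device, host; };  // stand-in for a device/host pair

int main() {
  // tpu_devices[replica][core]
  std::vector<std::vector<ToyTPUDevice>> tpu_devices = {
      {{"/task:0/device:TPU:0", "/task:0/device:CPU:0"}},
      {{"/task:1/device:TPU:0", "/task:1/device:CPU:0"}},
  };
  const int num_replicas = static_cast<int>(tpu_devices.size());
  const int num_cores_per_replica = static_cast<int>(tpu_devices.front().size());

  for (int core = 0; core < num_cores_per_replica; ++core) {
    std::cout << "core alias " << core << ":";
    for (int replica = 0; replica < num_replicas; ++replica)
      std::cout << " " << tpu_devices[replica][core].device;
    std::cout << "\n";
  }
  if (num_cores_per_replica == 1) {  // data parallelism: also publish replicated hosts
    std::cout << "replicated host alias:";
    for (int replica = 0; replica < num_replicas; ++replica)
      std::cout << " " << tpu_devices[replica][0].host;
    std::cout << "\n";
  }
}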
LogicalResult BuildParallelExecuteOp( - llvm::ArrayRef> execution_devices, + llvm::ArrayRef> + tpu_devices, llvm::ArrayRef output_sharding_config, Operation* compile_op, tf_device::ClusterFuncOp cluster_func, OpBuilder* builder, tf_device::ParallelExecuteOp* parallel_execute_op) { - const int num_cores_per_replica = execution_devices.front().size(); + const int num_cores_per_replica = tpu_devices.front().size(); // parallel_execute op returns concatenated list of return values of // all its regions. // @@ -528,7 +511,7 @@ LogicalResult BuildParallelExecuteOp( num_cores_per_replica, cluster_func, builder, &input_list); if (failed(result)) return failure(); - const bool replicated = execution_devices.size() != 1; + const bool replicated = tpu_devices.size() != 1; // For each logical core, create a region with TPUExecute op. assert(input_list.size() == num_cores_per_replica); for (int core = 0; core < num_cores_per_replica; ++core) { @@ -553,7 +536,7 @@ LogicalResult BuildParallelExecuteOp( // op. std::string device = replicated ? tensorflow::GetDeviceAliasForLogicalCore(core) - : execution_devices.front()[core]; + : tpu_devices.front()[core].device; auto region_launch_op = WrapOpInLaunch(builder, region.getParent()->getLoc(), execute, device); @@ -566,13 +549,14 @@ LogicalResult BuildParallelExecuteOp( } tf_device::LaunchOp AssignDevicesToReplicatedExecute( - llvm::ArrayRef> execution_devices, + llvm::ArrayRef> + tpu_devices, Operation* execute_op, OpBuilder* builder) { - const bool replicated = execution_devices.size() != 1; + const bool replicated = tpu_devices.size() != 1; // If computation is replicated, use aliased device. Otherwise there is only // one execution device and the device is assigned to the execute op. std::string device = replicated ? 
tensorflow::GetDeviceAliasForLogicalCore(0) - : execution_devices.front().front(); + : tpu_devices.front().front().device; return WrapOpInLaunch(builder, execute_op->getLoc(), execute_op, device); } @@ -658,27 +642,41 @@ LogicalResult Rewrite( : nullptr; if (replicate) num_replicas = replicate.n().getLimitedValue(); - auto num_cores_per_replica_attr = - cluster_func.getAttrOfType(kNumCoresPerReplicaAttr); + auto num_cores_per_replica_attr = cluster_func.getAttrOfType( + tensorflow::kNumCoresPerReplicaAttr); if (!num_cores_per_replica_attr) return cluster_func.emitOpError( - CreateMissingAttributeMsg(kNumCoresPerReplicaAttr)); + CreateMissingAttributeMsg(tensorflow::kNumCoresPerReplicaAttr)); int num_cores_per_replica = num_cores_per_replica_attr.getInt(); - auto topology_attr = cluster_func.getAttrOfType(kTopologyAttr); + auto topology_attr = + cluster_func.getAttrOfType(tensorflow::kTopologyAttr); if (!topology_attr) - return cluster_func.emitOpError(CreateMissingAttributeMsg(kTopologyAttr)); + return cluster_func.emitOpError( + CreateMissingAttributeMsg(tensorflow::kTopologyAttr)); - llvm::SmallVector device_assignment; - if (failed(GetDeviceCoordinates(cluster_func, &device_assignment))) - return failure(); + auto device_assignment_attr = cluster_func.getAttrOfType( + tensorflow::kDeviceAssignmentAttr); + if (!device_assignment_attr) + return cluster_func.emitOpError( + llvm::formatv("requires attribute '{0}'", + tensorflow::kDeviceAssignmentAttr) + .str()); + + auto status_or_device_coodinates = + tensorflow::GetDeviceCoordinates(device_assignment_attr); + if (!status_or_device_coodinates.ok()) + return cluster_func.emitError() + << "error in fetching tpu device coordinates: " + << status_or_device_coodinates.status().error_message(); // Determine compilation and execution devices. auto status_or_tpu_device_assignment = tensorflow::GetTPUCompilationAndExecutionDevices( devices, num_replicas, num_cores_per_replica, - topology_attr.getValue(), device_assignment); + topology_attr.getValue(), + status_or_device_coodinates.ConsumeValueOrDie()); if (!status_or_tpu_device_assignment.ok()) return cluster_func.emitError() << "error in fetching TPU compilation/execution devices: " @@ -687,12 +685,35 @@ LogicalResult Rewrite( // Create compile op. auto& tpu_device_assignment = status_or_tpu_device_assignment.ValueOrDie(); builder->setInsertionPoint(cluster_func); + + // Create the TPUCompileMlir and TPUCompileSucceededAssert outside of + // parallel_execute region if it exists. + if (llvm::isa(cluster_func.getParentOp())) { + // Currently, outside compilation and model parallelism are not supported + // together. + assert(num_cores_per_replica == 1); + builder->setInsertionPoint(cluster_func.getParentOp()); + } + Operation* compile_op = BuildCompileOp( cluster_func, num_replicas, num_cores_per_replica, tpu_device_assignment.compilation_device, std::move(tpu_device_assignment.xla_device_assignment), builder); if (!compile_op) return failure(); + // This replaces _TPUCompileMlir placeholder ops that are required + // by XlaRecvAtHost and XlaSendFromHost ops add in earlier pass. + // TODO(b/157054714): When a better abstraction instead of _TPUCompileMlirOp + // and _XlaRecvAtHostOp and _XlaSendFromHostOp are used, update to a more + // structured lowering. 
+ if (auto parallel_op = llvm::dyn_cast( + cluster_func.getParentOp())) { + parallel_op.walk([&](TF::_TPUCompileMlirOp parallel_compile_op) { + parallel_compile_op.replaceAllUsesWith(compile_op); + parallel_compile_op.erase(); + }); + } + // After rewrite, find if there is a TPUCompilationResultOp in the block with // the same _tpu_replicate attribute and replace it with the result of the // compile op. This op is used as a placeholder to hook during graph creation @@ -704,7 +725,7 @@ LogicalResult Rewrite( BuildTPUCompileSucceededAssertOp( compile_op, tpu_device_assignment.compilation_device, builder); - AssignDevicesToReplicate(replicate, tpu_device_assignment.execution_devices, + AssignDevicesToReplicate(replicate, tpu_device_assignment.tpu_devices, builder); llvm::SmallVector output_shardings; @@ -712,12 +733,13 @@ LogicalResult Rewrite( num_cores_per_replica, cluster_func, &output_shardings); if (failed(result)) return failure(); + builder->setInsertionPoint(cluster_func); if (num_cores_per_replica > 1) { // For model parallelism, tf_device.parallel_execute is used to express // concurrent device execution across multiple logical devices. tf_device::ParallelExecuteOp execute_op; - result = BuildParallelExecuteOp(tpu_device_assignment.execution_devices, + result = BuildParallelExecuteOp(tpu_device_assignment.tpu_devices, output_shardings, compile_op, cluster_func, builder, &execute_op); if (failed(result)) return failure(); @@ -740,7 +762,7 @@ LogicalResult Rewrite( if (failed(result)) return failure(); tf_device::LaunchOp launch_op = AssignDevicesToReplicatedExecute( - tpu_device_assignment.execution_devices, execute_op, builder); + tpu_device_assignment.tpu_devices, execute_op, builder); cluster_func.replaceAllUsesWith(launch_op); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc new file mode 100644 index 00000000000..7befa68f3d8 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc @@ -0,0 +1,703 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace mlir { +namespace TFTPU { + +namespace { + +constexpr char kDeviceAttr[] = "device"; +typedef std::pair Conv2DWithBlockSize; + +// A pass that applies automatic space to depth transform for the first or +// frontier convolutions consume host inputs on TPU. +// This is done by adding space to depth transform op after host input and +// applying space to depth transform for the first convolution and its backprop +// filter on TPU. +// +// Example: original program: +// +// module { +// func @while_body { +// %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}: +// -> tensor<2x224x224x3xf32> +// %device_launch = "tf_device.cluster_func"(%input,...) {func = @_func,...) +// return ... +// } +// func @_func(%input: tensor<2x224x224x3xf32>, +// %filter: tensor<7x7x3x64xf32>) { +// %6 = "tf.Conv2D"(%input, %filter) {strides = [1, 2, 2, 1]}: +// (tensor<2x230x230x3xf32>, tensor<7x7x3x64xf32>) -> +// tensor<2x112x112x64xf32> +// } +// } +// +// With this pass, the program will be transformed into: +// module { +// func @while_body { +// %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"} +// -> tensor<2x224x224x3xf32> +// %space_to_depth = "tf.SpaceToDepth"(%input) {block_size = 2, ...}: +// (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32> +// %device_launch = "tf_device.cluster_func"(%space_to_depth,...) +// {func = @_func,...) +// return ... +// } +// func @_func(%input: tensor<2x112x112x12xf32>, +// %filter: tensor<7x7x3x64xf32>) { +// %filter_transform = "tf.Pad/tf.Transpose/tf.Reshape"(%filter): +// tensor<7x7x3x64xf32>) -> tensor<4x4x12x64xf32> +// %conv = "tf.Conv2D"(%input, %filter_transfrom) {strides = [1, 1, 1, 1]}: +// (tensor<2x112x112x12xf32>, tensor<4x4x12x64xf32>) -> +// tensor<2x112x112x64xf32> +// } +// } +// +// This way, the first convolution with 3 feature dimension will be transformed +// to 12 feature dimension, which has better performance on TPU. 
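The shape arithmetic behind the example above, as a runnable check: block_size 2 turns the 2x224x224x3 host input into 2x112x112x12.

#include <iostream>
#include <vector>

// NHWC space-to-depth output shape: spatial dims shrink by block_size, the feature
// dimension grows by block_size * block_size.
std::vector<long> SpaceToDepthShape(const std::vector<long>& in, long block_size) {
  return {in[0], in[1] / block_size, in[2] / block_size,
          in[3] * block_size * block_size};
}

int main() {
  std::vector<long> out = SpaceToDepthShape({2, 224, 224, 3}, /*block_size=*/2);
  std::cout << out[0] << "x" << out[1] << "x" << out[2] << "x" << out[3] << "\n";
  // Prints 2x112x112x12, matching the transformed IteratorGetNext result above.
}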
+// +// TODO(wangtao): add a pass to check if it is profitable to space to depth +// transform and invoke the transform if it is needed. +struct TPUSpaceToDepthPass + : public PassWrapper> { + void runOnOperation() override; +}; + +// Handle padding before convolution for space to depth transform. +LogicalResult HandlePad(TF::PadOp op, int32_t kernel_size, int32_t block_size) { + auto ranked_type = op.input().getType().dyn_cast(); + if (!ranked_type) return failure(); + auto pad_input_shape = ranked_type.getShape(); + Location loc = op.getLoc(); + OpBuilder builder(op); + builder.setInsertionPoint(op); + auto padding_type = RankedTensorType::get({4, 2}, builder.getIntegerType(32)); + + // Calculate paddings. + int32_t pad_total = kernel_size - 1; + int32_t pad_beg = (pad_total / 2 + 1) / block_size; + int32_t pad_end = (pad_total / 2) / block_size; + SmallVector values = {0, 0, pad_beg, pad_end, + pad_beg, pad_end, 0, 0}; + auto paddings = DenseIntElementsAttr::get(padding_type, values); + // Update pad_op paddings. + op.setOperand(1, builder.create(loc, paddings)); + + // Set input type. + auto input = op.getOperand(0); + SmallVector transform_shape = { + pad_input_shape[0], pad_input_shape[1] / block_size, + pad_input_shape[2] / block_size, + pad_input_shape[3] * block_size * block_size}; + auto transform_result_type = + RankedTensorType::get(transform_shape, getElementTypeOrSelf(input)); + input.setType(transform_result_type); + op.setOperand(0, input); + return success(); +} + +// Handle stride for the first convolution for the transform. +void HandleConv2DStride(TF::Conv2DOp conv2d) { + MLIRContext* context = conv2d.getContext(); + SmallVector values = {1, 1, 1, 1}; + auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute { + return IntegerAttr::get(IntegerType::get(64, context), v); + }); + // TODO(b/157276506): change type of strides to DenseElementsAttr + auto strides = ArrayAttr::get(llvm::to_vector<4>(attrs), context); + conv2d.setAttr("strides", strides); +} + +// Transform input shape for the first convolution. +void HandleConv2DInput(TF::Conv2DOp conv2d, int64_t block_size) { + auto input = conv2d.input(); + auto input_shape = input.getType().cast().getShape(); + SmallVector transform_shape = { + input_shape[0], input_shape[1] / block_size, input_shape[2] / block_size, + input_shape[3] * block_size * block_size}; + auto transform_result_type = + RankedTensorType::get(transform_shape, getElementTypeOrSelf(input)); + input.setType(transform_result_type); +} + +// Add padding for convolution filter for space to depth transform. +TF::PadOp GetPadOpForConv2DFilter(ArrayRef filter_shape, Value filter, + OpBuilder* builder, int32_t pad_h, + int32_t pad_w) { + SmallVector values = {pad_h, 0, pad_w, 0, 0, 0, 0, 0}; + auto padding_type = + RankedTensorType::get({4, 2}, builder->getIntegerType(32)); + auto paddings = DenseIntElementsAttr::get(padding_type, values); + auto paddings_value = builder->create(filter.getLoc(), paddings); + std::vector pad_shape = {filter_shape[0] + pad_h, + filter_shape[1] + pad_w, filter_shape[2], + filter_shape[3]}; + SmallVector expand_shape(pad_shape.begin(), pad_shape.end()); + + auto expand_result_type = + RankedTensorType::get(expand_shape, getElementTypeOrSelf(filter)); + return builder->create(filter.getLoc(), expand_result_type, filter, + paddings_value); +} + +// Create reshape op for space to depth transform. 
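A quick numeric check of the HandlePad padding arithmetic above for a 7x7 kernel and block_size 2 (values only; the real pass materializes them as a tf.Const paddings operand).

#include <cstdint>
#include <iostream>

int main() {
  int32_t kernel_size = 7, block_size = 2;
  int32_t pad_total = kernel_size - 1;                 // 6
  int32_t pad_beg = (pad_total / 2 + 1) / block_size;  // (3 + 1) / 2 = 2
  int32_t pad_end = (pad_total / 2) / block_size;      // 3 / 2 = 1
  // Paddings are laid out as [[0,0],[pad_beg,pad_end],[pad_beg,pad_end],[0,0]] in NHWC.
  std::cout << "pad_beg=" << pad_beg << " pad_end=" << pad_end << "\n";
}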
+TF::ReshapeOp GetReshapeOpForConv2DFilter(ArrayRef new_shape, + Value input, OpBuilder* builder) { + auto reshape_result_type = + RankedTensorType::get(new_shape, getElementTypeOrSelf(input)); + auto reshape_type = RankedTensorType::get( + {static_cast(new_shape.size())}, builder->getIntegerType(64)); + auto reshape_sizes = DenseIntElementsAttr::get(reshape_type, new_shape); + auto reshape_value = + builder->create(input.getLoc(), reshape_sizes); + return builder->create(input.getLoc(), reshape_result_type, + input, reshape_value); +} + +// Create transpose op for shape to depth transform. +TF::TransposeOp GetTransposeOpForConv2DFilter(OpBuilder* builder, Value input) { + SmallVector permutation = {0, 2, 1, 3, 4, 5}; + auto permute_type = RankedTensorType::get({6}, builder->getIntegerType(32)); + auto permute_attr = DenseIntElementsAttr::get(permute_type, permutation); + auto permute_value = + builder->create(input.getLoc(), permute_attr); + return builder->create(input.getLoc(), input, permute_value); +} + +void HandleConv2DFilter(TF::Conv2DOp conv2d, int64_t block_size) { + // For example, if filter shape is [7, 7, 3, 64] with block_size 2, + // will apply below transforms to the filter: + // 1. Pad the filter to [8, 8, 3, 64] + // 2. Reshape to [4, 2, 4, 2, 3, 64] + // 3. Transpose to [4, 4, 2, 2, 3, 64] + // 4. Reshape to [4, 4, 12, 64] + auto filter = conv2d.filter(); + OpBuilder builder(conv2d); + builder.setInsertionPoint(conv2d); + // Book keeping filter information. + auto filter_shape = filter.getType().cast().getShape(); + int64_t height = filter_shape[0]; + int64_t width = filter_shape[1]; + int64_t channel = filter_shape[2]; + int64_t out_channel = filter_shape[3]; + // Value/Op before reshape op. + Value before_reshape_value = filter; + if (height % block_size != 0 || width % block_size != 0) { + // Calculate paddings for height and width. + int32_t pad_h = block_size - height % block_size; + int32_t pad_w = block_size - width % block_size; + auto pad_op = + GetPadOpForConv2DFilter(filter_shape, filter, &builder, pad_h, pad_w); + // Update op, height and width before reshape. + before_reshape_value = pad_op; + height = height + pad_h; + width = width + pad_w; + } + + // Reshape. + SmallVector new_shape = { + height / block_size, block_size, width / block_size, + block_size, channel, out_channel}; + auto reshape_op = + GetReshapeOpForConv2DFilter(new_shape, before_reshape_value, &builder); + + // Transpose. + auto transpose_op = GetTransposeOpForConv2DFilter(&builder, reshape_op); + + // Reshape Back. + SmallVector final_shape = { + height / block_size, width / block_size, + channel * block_size * block_size, out_channel}; + auto final_reshape_op = + GetReshapeOpForConv2DFilter(final_shape, transpose_op, &builder); + // Update filter of Conv2D. + conv2d.setOperand(1, final_reshape_op); +} + +// Create slice op for filter in back prop pass. 
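To make the four-step filter rewrite in HandleConv2DFilter easier to follow (the backprop-filter helpers below apply the same steps in reverse and then slice back to the original shape), this shape-only trace reproduces the intermediate shapes for the [7, 7, 3, 64] filter with block_size 2 from the comment; it is an illustration of the bookkeeping only and does not touch the MLIR rewrites themselves.

```c++
#include <cstdint>
#include <iostream>
#include <vector>

// Shape-only trace of the filter rewrite; the real pass builds tf.Pad,
// tf.Reshape and tf.Transpose ops instead of manipulating vectors.
void TraceFilterRewrite(int64_t h, int64_t w, int64_t in_c, int64_t out_c,
                        int64_t b) {
  auto pad = [b](int64_t d) { return d % b == 0 ? d : d + (b - d % b); };
  int64_t ph = pad(h), pw = pad(w);  // 1. Pad:       [8, 8, 3, 64]
  std::vector<int64_t> reshaped = {ph / b, b, pw / b,
                                   b, in_c, out_c};  // 2. Reshape: [4, 2, 4, 2, 3, 64]
  // 3. Transpose with permutation {0, 2, 1, 3, 4, 5}: [4, 4, 2, 2, 3, 64]
  std::vector<int64_t> transposed = {reshaped[0], reshaped[2], reshaped[1],
                                     reshaped[3], reshaped[4], reshaped[5]};
  // 4. Final reshape:                                 [4, 4, 12, 64]
  std::vector<int64_t> final_shape = {ph / b, pw / b, in_c * b * b, out_c};
  std::cout << final_shape[0] << "x" << final_shape[1] << "x" << final_shape[2]
            << "x" << final_shape[3] << "\n";
}

int main() { TraceFilterRewrite(7, 7, 3, 64, 2); }  // prints 4x4x12x64
```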
+TF::SliceOp GetSliceOpForConv2DBackPropFilter( + ArrayRef old_filter_shape, Value input, OpBuilder* builder) { + SmallVector slice_size(old_filter_shape.begin(), + old_filter_shape.end()); + auto slice_result_type = + RankedTensorType::get(slice_size, getElementTypeOrSelf(input)); + auto slice_size_op = builder->create( + input.getLoc(), + DenseIntElementsAttr::get( + RankedTensorType::get({4}, builder->getIntegerType(32)), + old_filter_shape)); + SmallVector slice_start_position = {0, 0, 0, 0}; + auto start_position_type = + RankedTensorType::get({4}, builder->getIntegerType(64)); + auto start_position = builder->create( + input.getLoc(), + DenseIntElementsAttr::get(start_position_type, slice_start_position)); + return builder->create(input.getLoc(), slice_result_type, input, + start_position, slice_size_op); +} + +// Transform Conv2DBackPropFilter for space to depth. +void HandleConv2DBackPropFilter(TF::Conv2DBackpropFilterOp backprop, + ArrayRef old_filter_shape, + ArrayRef new_filter_shape, + int64_t block_size) { + OpBuilder builder(backprop); + builder.setInsertionPoint(backprop); + + auto input = backprop.input(); + // Get new filter size from new_filter_shape. + auto new_filter_sizes = builder.create( + backprop.getLoc(), + DenseIntElementsAttr::get( + RankedTensorType::get({4}, builder.getIntegerType(32)), + new_filter_shape)); + + // Set stride to [1, 1, 1, 1]. + MLIRContext* context = backprop.getContext(); + SmallVector values = {1, 1, 1, 1}; + auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute { + return IntegerAttr::get(IntegerType::get(64, context), APInt(64, v)); + }); + auto strides = ArrayAttr::get(llvm::to_vector<4>(attrs), context); + + // new result type. + SmallVector new_shape(new_filter_shape.begin(), + new_filter_shape.end()); + auto new_result_type = + RankedTensorType::get(new_shape, getElementTypeOrSelf(input)); + + // Build new BackPropFilterOp. + auto loc = backprop.getLoc(); + auto new_backprop = builder.create( + loc, new_result_type, input, new_filter_sizes, backprop.out_backprop(), + strides, backprop.use_cudnn_on_gpu(), backprop.padding(), + backprop.explicit_paddings(), backprop.data_format(), + backprop.dilations()); + + // For example, if new filter shape is [4, 4, 12, 64], old filter shape + // is [7, 7, 3, 64] with block_size 2. + // Below transforms will be applied to the filter: + // 1. Reshape to [4, 4, 2, 2, 3, 64]; + // 2. Transpose to [4, 2, 4, 2, 3, 64]; + // 3. Reshape to [8, 8, 3, 64]; + // 4. Slice to [7, 7, 3, 64]. + SmallVector first_reshape_shape = { + new_filter_shape[0], + new_filter_shape[1], + block_size, + block_size, + new_filter_shape[2] / (block_size * block_size), + new_filter_shape[3]}; + auto first_reshape_op = + GetReshapeOpForConv2DFilter(first_reshape_shape, new_backprop, &builder); + + // Transpose. + auto transpose_op = GetTransposeOpForConv2DFilter(&builder, first_reshape_op); + + // Last Reshape op. + SmallVector last_reshape_shape = { + new_filter_shape[0] * block_size, new_filter_shape[1] * block_size, + new_filter_shape[2] / (block_size * block_size), new_filter_shape[3]}; + auto final_reshape_op = + GetReshapeOpForConv2DFilter(last_reshape_shape, transpose_op, &builder); + + // create slice op. + auto slice_op = GetSliceOpForConv2DBackPropFilter(old_filter_shape, + final_reshape_op, &builder); + + // Update backprop's user with the slice op. + backprop.replaceAllUsesWith(slice_op.getResult()); +} + +// Update func arugument type to have the updated input shape. 
+void UpdateFuncType(FuncOp func) { + llvm::SmallVector arg_types; + arg_types.reserve(func.getNumArguments()); + for (auto arg : func.getArguments()) arg_types.emplace_back(arg.getType()); + auto terminator = func.front().getTerminator(); + SmallVector result_types(terminator->operand_type_begin(), + terminator->operand_type_end()); + func.setType(FunctionType::get(arg_types, result_types, func.getContext())); +} + +void HandleFuncOp(Operation* op) { + auto func = llvm::cast(op); + UpdateFuncType(func); +} + +// Checks if the input producer op is supported in this transform. Right now, we +// only check if it is a host tf.IteratorGetNext. +bool IsSupportedHostInputOp(Operation* op) { + TF::IteratorGetNextOp iter = llvm::dyn_cast(op); + if (!iter) return false; + auto device = op->getAttrOfType(kDeviceAttr); + if (!device) return false; + tensorflow::DeviceNameUtils::ParsedName parsed_device; + if (!tensorflow::DeviceNameUtils::ParseFullName(device.getValue().str(), + &parsed_device)) { + return false; + } + return parsed_device.type == "CPU"; +} + +// Builds a SpaceToDepthOp with the given get_layout op and input. +TF::SpaceToDepthOp BuildSpaceToDepth(tf_device::ClusterFuncOp cluster_func, + Value input, int32_t block_size, + ArrayRef input_shape) { + auto input_op = input.getDefiningOp(); + OpBuilder builder(input_op); + builder.setInsertionPointAfter(input_op); + SmallVector transform_shape = { + input_shape[0], input_shape[1] / block_size, input_shape[2] / block_size, + input_shape[3] * block_size * block_size}; + auto transform_result_type = + RankedTensorType::get(transform_shape, getElementTypeOrSelf(input)); + return builder.create(cluster_func.getLoc(), + transform_result_type, input, + APInt(64, block_size)); +} + +// Performs transformation for a non-replicated input. +TF::SpaceToDepthOp HandleHostInput(Value input, int64_t index, + tf_device::ClusterFuncOp cluster_func, + int32_t block_size, + ArrayRef input_shape) { + auto space_to_depth = + BuildSpaceToDepth(cluster_func, input, block_size, input_shape); + cluster_func.setOperand(index, space_to_depth); + return space_to_depth; +} + +// Performs transformation for replicated inputs. Returns true if this is a +// supported case (thus transform happened). +bool HandleHostReplicatedInputs(int64_t index, + tf_device::ClusterFuncOp cluster_func, + int64_t replicate_arg_index, + tf_device::ReplicateOp replicate, + int32_t block_size) { + // We need to know the devices to copy to. + if (!replicate.devices()) return false; + int64_t num_replicas = replicate.n().getZExtValue(); + // Gets inputs at replicate_arg_index for each replica. + auto inputs = replicate.getOperands() + .drop_front(replicate_arg_index * num_replicas) + .take_front(num_replicas); + for (auto input : inputs) { + auto input_op = input.getDefiningOp(); + if (!input_op || !IsSupportedHostInputOp(input_op)) return false; + } + for (auto entry : llvm::enumerate(inputs)) { + auto ranked_type = entry.value().getType().dyn_cast(); + if (!ranked_type) return false; + auto input_shape = ranked_type.getShape(); + auto space_to_depth = + BuildSpaceToDepth(cluster_func, entry.value(), block_size, input_shape); + replicate.setOperand(num_replicas * replicate_arg_index + entry.index(), + space_to_depth); + } + return true; +} + +// Performs transformation on a pair of execute and compile ops. The compile +// should not have other uses. 
+void HandleCluster(tf_device::ClusterFuncOp cluster_func, int32_t block_size, + unsigned arg_num) { + auto maybe_replicate = + llvm::dyn_cast(cluster_func.getParentOp()); + + llvm::SmallVector transform_input_indices; + for (auto input : llvm::enumerate(cluster_func.operands())) { + if (auto block_arg = input.value().dyn_cast()) { + if (block_arg.getArgNumber() != arg_num) continue; + // For a block argument, consider transforms only when it is a replicated + // input (defining ops will be outside the replicate node). + if (maybe_replicate == block_arg.getParentRegion()->getParentOp()) { + HandleHostReplicatedInputs(input.index(), cluster_func, + block_arg.getArgNumber(), maybe_replicate, + block_size); + } + } else { + // For an op output, consider transforms only when 1) there is no + // replicateion or 2) it is outside the replicate node that encloses the + // execute node. (Because if the op is inside replicate, it is probably + // not on the host.) + if (input.index() != arg_num) continue; + auto input_op = input.value().getDefiningOp(); + if (maybe_replicate && + maybe_replicate.body().isAncestor(input_op->getParentRegion())) { + continue; + } + if (!IsSupportedHostInputOp(input_op)) continue; + auto ranked_type = input.value().getType().dyn_cast(); + if (!ranked_type) continue; + auto input_shape = ranked_type.getShape(); + HandleHostInput(input.value(), input.index(), cluster_func, block_size, + input_shape); + } + } +} + +// Check if input shape of convolution is good for space to depth transform. +bool Conv2DInputShapeCanTransform(Value input) { + auto ranked_type = input.getType().dyn_cast(); + if (!ranked_type) return false; + auto input_shape = ranked_type.getShape(); + int32_t batch_size = input_shape[0]; + int32_t channel = input_shape[3]; + if (batch_size > 8 || channel > 8) { + return false; + } + return true; +} + +// Checks if a convoluton can apply SpaceToDepth transform. +// Only the first convolution in the graph whose batch size smaller than 8 +// and its input feature size smaller than 8 can be transformed. +Optional> GetConv2DInputArgNum(TF::Conv2DOp conv2d) { + if (conv2d.data_format() != "NHWC" || conv2d.strides().size() != 4) { + return None; + } + auto conv2d_input = conv2d.input(); + if (auto block_arg = conv2d_input.dyn_cast()) { + if (!Conv2DInputShapeCanTransform(conv2d_input)) return None; + int num_users = + std::distance(block_arg.getUsers().begin(), block_arg.getUsers().end()); + return std::make_pair(block_arg.getArgNumber(), num_users); + } + + if (auto pad_op = llvm::dyn_cast(conv2d_input.getDefiningOp())) { + auto pad_input = pad_op.input(); + if (auto block_arg = pad_input.dyn_cast()) { + if (!Conv2DInputShapeCanTransform(pad_input)) return None; + int num_users = std::distance(block_arg.getUsers().begin(), + block_arg.getUsers().end()); + return std::make_pair(block_arg.getArgNumber(), num_users); + } + } + + return None; +} + +// Apply space to depth transform for the first convolution on TPU device. +void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) { + // Check if input and filter type are RankedTensorType. + auto input_tensor_type = + conv2d.input().getType().dyn_cast(); + auto filter_tensor_type = + conv2d.filter().getType().dyn_cast(); + if (!input_tensor_type || !filter_tensor_type) return; + // Book keeping filter shape for padding and backprop filter rewrite. + auto filter_shape = filter_tensor_type.getShape(); + SmallVector old_filter_shape(filter_shape.begin(), + filter_shape.end()); + // Handles input. 
+ auto conv2d_input = conv2d.input(); + if (auto block_arg = conv2d_input.dyn_cast()) { + // Change on device function type/shape. + HandleFuncOp(block_arg.getOwner()->getParentOp()); + } + + if (auto pad_op = dyn_cast_or_null(conv2d_input.getDefiningOp())) { + // Rewrite pad_op before Convolutioin. + if (failed(HandlePad(pad_op, filter_shape[0], block_size))) return; + auto pad_input = pad_op.input(); + if (auto block_arg = pad_input.dyn_cast()) { + // Change on device function type/shape. + HandleFuncOp(block_arg.getOwner()->getParentOp()); + } + } + + // Handle Conv2D input, stride and filter. + HandleConv2DInput(conv2d, block_size); + HandleConv2DStride(conv2d); + HandleConv2DFilter(conv2d, block_size); + + // Book keeping new filter shape for backprop filter rewrite. + // Filter shape is defined in HandleConv2DFilter, thus it is RankedTensorType. + filter_shape = conv2d.filter().getType().cast().getShape(); + SmallVector new_filter_shape(filter_shape.begin(), + filter_shape.end()); + + // Rewrite Conv2DBackPropFilter after the first convolution. + for (Operation* user : conv2d.getOperation()->getUsers()) { + if (auto backprop = dyn_cast(user)) { + HandleConv2DBackPropFilter(backprop, old_filter_shape, new_filter_shape, + block_size); + } + } +} + +// Get block size that is equal to stride from spatial dimension +// from convolution. +// Space to depth transform won't be triggered if block size <= 1. +int32_t GetConv2DBlockSize(TF::Conv2DOp conv2d) { + SmallVector strides(4, 1); + for (int i = 0; i < 3; ++i) { + strides[i] = conv2d.strides()[i].cast().getInt(); + } + + // Space to depth only supports striding at spatial dimension. + if (strides[0] != 1 || strides[3] != 1) return 1; + + // Space to depth only supports height_stride == width_stride case. + if (strides[1] != strides[2]) return 1; + + return strides[1]; +} + +void TPUSpaceToDepthPass::runOnOperation() { + Optional cluster_func; + // Space to depth only supports training loop. + auto func_result = getOperation().walk([&](tf_device::ClusterFuncOp cluster) { + cluster_func = cluster; + return WalkResult::interrupt(); + }); + + // Return if there is no tf_device::ClusterFuncOp in training loop. + if (!func_result.wasInterrupted() || !cluster_func.hasValue()) { + return; + } + + // Get the function on device. + auto device_func = + getOperation().lookupSymbol(cluster_func->getFunc()); + if (!device_func) return; + + TF::Conv2DOp first_conv; + Optional> input_shape; + // A map maps block argument id to the convolutions consumes them. + llvm::SmallDenseMap> + argnum_and_convolutions; + // A map maps block argument id to the number of users. + llvm::SmallDenseMap argnum_num_users; + + // Find out the qualified convolutions and its block argument ids. + auto conv2d_result = device_func.walk([&](TF::Conv2DOp conv2d) { + Optional> arg_num_and_num_users = + GetConv2DInputArgNum(conv2d); + if (arg_num_and_num_users.hasValue()) { + // Get block size for the first convolution. + int64_t block_size = GetConv2DBlockSize(conv2d); + auto arg_num = arg_num_and_num_users.getValue().first; + auto num_users = arg_num_and_num_users.getValue().second; + argnum_and_convolutions[arg_num].emplace_back(conv2d, block_size); + argnum_num_users[arg_num] = num_users; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + if (!conv2d_result.wasInterrupted()) { + return; + } + + // Iterate through block argument and its convolution users. 
Space to depth + // transform will be applied only if all the below conditions are satisfied: + // 1. All the users of the block argument will lead to convolutions; + // 2. block_size of for the space to depth transform for these convolutions + // are the same; + // 3. block_size of for the space to depth transform for these convolutions + // are larger than 1. + for (auto argnum_and_convolution : argnum_and_convolutions) { + auto arg_num = argnum_and_convolution.getFirst(); + auto conv2d_and_block_sizes = argnum_and_convolution.getSecond(); + // Continue if number of users of the block argment doesn't equal to number + // of transformable convolutions and there is no qualified convolution + // for transform or block size is smaller than 2. + if (argnum_num_users[arg_num] != conv2d_and_block_sizes.size() || + conv2d_and_block_sizes.empty()) { + argnum_and_convolutions.erase(arg_num); + continue; + } + int64_t block_size = conv2d_and_block_sizes[0].second; + if (block_size < 2) { + argnum_and_convolutions.erase(arg_num); + continue; + } + // Continue if not all the block sizes for space to depth transform are the + // same. + for (auto conv2d_and_block_size : conv2d_and_block_sizes) { + if (conv2d_and_block_size.second != block_size) { + argnum_and_convolutions.erase(arg_num); + break; + } + } + } + + // If there is no qualified space to depth transform. + if (argnum_and_convolutions.empty()) { + return; + } + + // Apply space to depth transform. + for (auto argnum_and_convolution : argnum_and_convolutions) { + auto conv2d_and_block_sizes = argnum_and_convolution.getSecond(); + int64_t block_size = conv2d_and_block_sizes[0].second; + // Apply space to depth transform to the input on the host. + HandleCluster(cluster_func.getValue(), block_size, + argnum_and_convolution.getFirst()); + // Transform the convolution. + for (auto conv2d_and_block_size : conv2d_and_block_sizes) { + HandleFirstConvolution(conv2d_and_block_size.first, + conv2d_and_block_size.second); + } + } +} + +} // namespace + +std::unique_ptr> CreateTPUSpaceToDepthPass() { + return std::make_unique(); +} + +static PassRegistration pass( + "tf-tpu-space-to-depth-pass", + "Adds ops that allow TPU program enable automaic space to depth for the" + "convolution determined at JIT compile time."); + +} // namespace TFTPU +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc index 9e8745918e3..ec4a25c6fdd 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc @@ -229,7 +229,7 @@ AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping( mapping.emplace_back(it->second, std::move(while_args)); } // Sort the mapping according to execute operand order. - llvm::sort(mapping); + llvm::sort(mapping, llvm::less_first()); // Populate the `retval_index_for_sharding` field of the argument metadate. 
for (auto entry : llvm::enumerate(execute.device_var_reads_indices())) { int64_t arg_index = entry.value().cast().getInt(); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index 75fcede8fbb..8e51f8c9a25 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -70,7 +70,6 @@ using llvm::isa; using mlir::BlockArgument; using mlir::Dialect; using mlir::Operation; -using mlir::OperationState; using mlir::Value; using stream_executor::port::StatusOr; @@ -79,6 +78,9 @@ namespace { constexpr char kInvalidExecutorGraphMsg[] = "Functions must be of a single Graph with single op Islands: "; +constexpr char kDeviceAttr[] = "tf.device"; +constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id"; + bool IsLegalChar(char c, bool first_char) { if (isalpha(c)) return true; if (isdigit(c)) return true; @@ -267,17 +269,14 @@ StatusOr> Exporter::GetArgumentNode( (*node_def->mutable_attr())["index"] = index_attr; if (auto device_attr = - func.getArgAttrOfType(index, "tf.device")) { + func.getArgAttrOfType(index, kDeviceAttr)) *node_def->mutable_device() = device_attr.getValue().str(); - } - if (auto resource_arg_unique_id_attr = - func.getArgAttrOfType( - index, "tf.resource_arg_unique_id")) { - AttrValue unique_id_attr; - unique_id_attr.set_i(resource_arg_unique_id_attr.getInt()); - (*node_def->mutable_attr())["_resource_arg_unique_id"] = unique_id_attr; - } + llvm::ArrayRef func_arg_i_attrs = + func.getArgAttrs(index); + absl::flat_hash_set attrs_to_ignore = {kDeviceAttr}; + TF_RETURN_IF_ERROR(ConvertAttributes(func_arg_i_attrs, attrs_to_ignore, + node_def->mutable_attr())); return node_def; } @@ -682,14 +681,6 @@ Status Exporter::ConvertLibFunction(const GraphExportConfig& configs, if (auto attr = function.getAttrOfType(stateful_string)) { func_def.mutable_signature()->set_is_stateful(true); } - for (int64 i = 0; i < function.getNumArguments(); ++i) { - if (auto resource_arg_unique_id_attr = - function.getArgAttrOfType( - i, "tf.resource_arg_unique_id")) { - (*func_def.mutable_resource_arg_unique_id())[i] = - resource_arg_unique_id_attr.getInt(); - } - } // Ignore the gradient and is_stateful attribute on the function as they have // been handled above. 
@@ -699,7 +690,28 @@ Status Exporter::ConvertLibFunction(const GraphExportConfig& configs, function.getDialectAttrs()); TF_RETURN_IF_ERROR( ConvertAttributes(funcAttrs, attrs_to_ignore, func_def.mutable_attr())); - (*flib->add_function()) = func_def; + + for (int i = 0, e = function.getNumArguments(); i < e; ++i) { + if (auto resource_arg_unique_id_attr = + function.getArgAttrOfType( + i, kResourceArgUniqueIdAttr)) { + (*func_def.mutable_resource_arg_unique_id())[i] = + resource_arg_unique_id_attr.getInt(); + } + + llvm::ArrayRef func_arg_i_attrs = + function.getArgAttrs(i); + if (func_arg_i_attrs.empty()) continue; + absl::flat_hash_set attrs_to_ignore = { + kDeviceAttr, kResourceArgUniqueIdAttr}; + FunctionDef::ArgAttrs func_def_arg_i_attrs; + TF_RETURN_IF_ERROR(ConvertAttributes(func_arg_i_attrs, attrs_to_ignore, + func_def_arg_i_attrs.mutable_attr())); + if (func_def_arg_i_attrs.attr().empty()) continue; + (*func_def.mutable_arg_attr())[i] = std::move(func_def_arg_i_attrs); + } + + (*flib->add_function()) = std::move(func_def); return Status::OK(); } @@ -782,4 +794,22 @@ StatusOr> ConvertMlirToGraphdef( return graphdef; } +stream_executor::port::Status ConvertMlirFunctionToFunctionLibraryDef( + mlir::FuncOp func, const GraphExportConfig& configs, + FunctionDef* function_def) { + Dialect* tf_dialect = func.getContext()->getRegisteredDialect("tf"); + FunctionDefLibrary flib; + TF_RETURN_IF_ERROR( + Exporter::ConvertLibFunction(configs, tf_dialect, func, &flib)); + for (auto& func_def : flib.function()) { + if (func_def.signature().name() == func.getName()) { + *function_def = func_def; + return Status::OK(); + } + } + return errors::InvalidArgument( + "Function couldn't be found in the FunctionDefLibrary after converting " + "from MLIR"); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h index 2d522f6031e..a5aebd16146 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h @@ -18,6 +18,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "llvm/ADT/StringRef.h" +#include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project @@ -50,6 +51,12 @@ stream_executor::port::Status ConvertMlirToGraph( stream_executor::port::Status ConvertMlirToGraph( mlir::ModuleOp module, const GraphExportConfig& configs, std::unique_ptr* graph, FunctionLibraryDefinition* flib_def); + +// Converts an MLIR function and adds it to a FunctionLibraryDefinition. +stream_executor::port::Status ConvertMlirFunctionToFunctionLibraryDef( + mlir::FuncOp func, const GraphExportConfig& configs, + FunctionDef* function_def); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_EXPORT_GRAPHDEF_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 49be3da912a..3aa700d3718 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -40,7 +40,9 @@ limitations under the License. 
#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSet.h" #include "llvm/ADT/Twine.h" +#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project @@ -57,6 +59,7 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Verifier.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/jit/shape_inference_helpers.h" #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h" @@ -65,6 +68,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" @@ -109,6 +113,7 @@ static inline absl::string_view StringRefToView(llvm::StringRef ref) { } namespace tensorflow { +using mlir::NamedAttrList; using mlir::TensorType; using mlir::TF::VarHandleOp; using mlir::tf_saved_model::GlobalTensorOp; @@ -128,6 +133,13 @@ bool IsOutputShapesAttribute(const AttrValue& attr_value, attr_value.value_case() == AttrValue::kList; } +bool IsResourceOutputShapesAttribute(const AttrValue& attr_value, + llvm::StringRef attr_name) { + if (attr_name == "_handle_dtypes" || attr_name == "_handle_shapes") + return attr_value.value_case() == AttrValue::kList; + return false; +} + // This class is used to generate new MLIR function name strings that are both // unique in the TF function library `flib_` and unique among the name strings // generated by the class object during its lifetime. @@ -191,15 +203,11 @@ class ImporterBase { StatusOr InferLibFunctionType(const FunctionBody& fbody); // Extracts arg and ret nodes from FunctionBody. - // `resource_arg_unique_ids` will be filled with the unique IDs of resource - // variables, as a list of {index, ID} pairs. void GetArgsAndRetsFromFunctionBody( const FunctionBody& fbody, absl::InlinedVector* arg_nodes, absl::InlinedVector* ret_nodes, - absl::InlinedVector* control_ret_nodes, - absl::InlinedVector, 4>* - resource_arg_unique_ids); + absl::InlinedVector* control_ret_nodes); // Prepares converting the graph to an MLIR module. This step removes the // backedges of the graph, orders the nodes and infers the shapes. @@ -213,8 +221,7 @@ class ImporterBase { const absl::InlinedVector& ret_nodes, const absl::InlinedVector& control_ret_nodes, llvm::ArrayRef attrs, - const absl::InlinedVector, 4>& - resource_arg_unique_ids); + bool function_graph); // Finds out the function definition for the given function name from the // graph and converts it to a function of the module. This method is called @@ -306,9 +313,9 @@ class ImporterBase { // AttrValue {name : foo, attrs : {k1 : bar, k2 : rfc}}, it will convert it to // a list of MLIR Attributes: [{base_name : foo}, {base_name.k1 : bar}, // {base_name.k2 : rfc}}. 
- Status ConvertFunctionCallAttribute( - const std::string& base_name, const AttrValue& value, - llvm::SmallVector* attributes); + Status ConvertFunctionCallAttribute(const std::string& base_name, + const AttrValue& value, + NamedAttrList* attributes); // Helper to create either a tf_executor operation or a TF operation wrapped // in an island. When convert_to_legacy_call is true, converts the operation @@ -974,7 +981,6 @@ StatusOr ImporterBase::InferOutputType(const Node& node, int idx, if (dtype == DT_RESOURCE) { const AttrValue* dtype_attr = node.attrs().Find("_handle_dtypes"); const AttrValue* shape_attr = node.attrs().Find("_handle_shapes"); - LOG(INFO) << dtype_attr << " " << shape_attr; if (dtype_attr && shape_attr) { if (dtype_attr->list().type().empty()) { return errors::InvalidArgument( @@ -1089,9 +1095,9 @@ StatusOr ImporterBase::ConvertSubtypes( return subtypes; } -Status ImporterBase::ConvertFunctionCallAttribute( - const std::string& base_name, const AttrValue& value, - llvm::SmallVector* attributes) { +Status ImporterBase::ConvertFunctionCallAttribute(const std::string& base_name, + const AttrValue& value, + NamedAttrList* attributes) { TF_ASSIGN_OR_RETURN(auto func_attr, ConvertFunctionCallName(value.func().name())); attributes->push_back(builder_.getNamedAttr(base_name, func_attr)); @@ -1165,8 +1171,18 @@ StatusOr ImporterBase::ConvertAttributeValue( return builder_.getArrayAttr( llvm::makeArrayRef(attrs.begin(), attrs.end())); } - case AttrValue::kFunc: - return errors::Unknown("kFunc type should be handled separately!"); + case AttrValue::kFunc: { + // TODO(b/156546237): Unify kFunc/NameAttrList attribute representation. + // Currently kFunc/NameAttrList attributes in a kList/repeated AttrValue + // will not use this representation. + NamedAttrList attrs; + for (const auto& func_attr : value.func().attr()) { + TF_ASSIGN_OR_RETURN(auto attr, ConvertAttributeValue(func_attr.second)); + attrs.push_back(builder_.getNamedAttr(func_attr.first, attr)); + } + auto func_attrs = builder_.getDictionaryAttr(attrs); + return mlir::TF::FuncAttr::get(context_, value.func().name(), func_attrs); + } case AttrValue::VALUE_NOT_SET: return builder_.getUnitAttr(); // kPlaceholder is not implemented. 
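The new kFunc handling above packs a function attribute's nested attrs into a dictionary-backed FuncAttr, whereas the older flattening convention documented for ConvertFunctionCallAttribute expands {name: foo, attrs: {k1: bar, k2: rfc}} into base_name, base_name.k1, base_name.k2 entries. Below is a minimal sketch of that naming scheme using plain containers; the struct and map here are hypothetical stand-ins, not the MLIR attribute types.

```c++
#include <iostream>
#include <map>
#include <string>

// Hypothetical model of an AttrValue.func payload: a callee name plus its
// nested attributes.
struct FuncAttrValue {
  std::string name;
  std::map<std::string, std::string> attrs;
};

// Flatten into named attributes following the documented convention:
// {base_name: name, base_name.k1: v1, ...}.
std::map<std::string, std::string> Flatten(const std::string& base_name,
                                           const FuncAttrValue& value) {
  std::map<std::string, std::string> out;
  out[base_name] = value.name;
  for (const auto& kv : value.attrs)
    out[base_name + "." + kv.first] = kv.second;
  return out;
}

int main() {
  FuncAttrValue f{"foo", {{"k1", "bar"}, {"k2", "rfc"}}};
  for (const auto& kv : Flatten("body", f))
    std::cout << kv.first << " = " << kv.second << "\n";
  // body = foo, body.k1 = bar, body.k2 = rfc
}
```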
@@ -1179,9 +1195,7 @@ StatusOr ImporterBase::ConvertAttributeValue( void ImporterBase::GetArgsAndRetsFromFunctionBody( const FunctionBody& fbody, absl::InlinedVector* arg_nodes, absl::InlinedVector* ret_nodes, - absl::InlinedVector* control_ret_nodes, - absl::InlinedVector, 4>* - resource_arg_unique_ids) { + absl::InlinedVector* control_ret_nodes) { arg_nodes->reserve(fbody.arg_nodes.size()); ret_nodes->reserve(fbody.ret_nodes.size()); for (auto arg : fbody.arg_nodes) { @@ -1190,9 +1204,6 @@ void ImporterBase::GetArgsAndRetsFromFunctionBody( for (auto ret : fbody.ret_nodes) { ret_nodes->emplace_back(ret, 0); } - for (const auto& entry : fbody.fdef.resource_arg_unique_id()) { - resource_arg_unique_ids->push_back(entry); - } *control_ret_nodes = fbody.control_ret_nodes; } @@ -1287,14 +1298,13 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) { absl::InlinedVector arg_nodes; absl::InlinedVector ret_nodes; absl::InlinedVector control_ret_nodes; - absl::InlinedVector, 4> resource_arg_unique_ids; GetArgsAndRetsFromFunctionBody(*fbody, &arg_nodes, &ret_nodes, - &control_ret_nodes, &resource_arg_unique_ids); + &control_ret_nodes); TF_RETURN_IF_ERROR(child_importer.Convert( mlir_func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes, llvm::makeArrayRef(attributes.begin(), attributes.end()), - resource_arg_unique_ids)); + /*function_graph=*/true)); return Status::OK(); } @@ -1396,9 +1406,7 @@ Status ImporterBase::Convert( const absl::InlinedVector& arg_nodes, const absl::InlinedVector& ret_nodes, const absl::InlinedVector& control_ret_nodes, - llvm::ArrayRef attrs, - const absl::InlinedVector, 4>& - resource_arg_unique_ids) { + llvm::ArrayRef attrs, bool function_graph) { // TODO(b/122040776): Uses debug info for FunctionDef. auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(context_), func_name, func_type, attrs); @@ -1424,10 +1432,6 @@ Status ImporterBase::Convert( TF_RETURN_IF_ERROR(ConvertFunctionArgAndRets(function, graph, func_type.getInputs(), arg_nodes, ret_nodes, control_ret_nodes)); - for (const auto& entry : resource_arg_unique_ids) { - function.setArgAttr(entry.first, "tf.resource_arg_unique_id", - builder_.getI64IntegerAttr(entry.second)); - } // TODO(jpienaar): Update post removing shape_refinier_. if (!specs_.enable_shape_inference) { @@ -1486,6 +1490,22 @@ Status ImporterBase::ConvertFunctionArgAndRets( i, "tf.device", builder_.getStringAttr(arg_node.node->requested_device())); + if (arg_node.node->IsArg()) { + for (const auto& arg_node_attr : arg_node.node->attrs()) { + const auto& key = arg_node_attr.first; + // Only import attributes starting with an underscore. + if (key.empty() || key[0] != '_') continue; + // Ignore shape inference attributes as shape information is already + // populated in the result type. + if (IsOutputShapesAttribute(arg_node_attr.second, key) || + IsResourceOutputShapesAttribute(arg_node_attr.second, key)) + continue; + TF_ASSIGN_OR_RETURN(auto converted_attr, + ConvertAttributeValue(arg_node_attr.second)); + func.setArgAttr(i, llvm::formatv("tf.{0}", key).str(), converted_attr); + } + } + island->dropAllReferences(); island->erase(); } @@ -2095,14 +2115,10 @@ class GraphDefImporter : public ImporterBase { // output nodes, for function graphs. Arguments and return values are // determined by node op type. Type and shape information of the function are // inferred by the shape refiner in ImporterBase. 
- // `resource_arg_unique_ids` will be filled with the unique IDs of resource - // variables, as a list of {index, ID} pairs. StatusOr GetArgsRetsAndTypesFromFunctionGraph( mlir::MLIRContext* context, absl::InlinedVector* arg_nodes, - absl::InlinedVector* ret_nodes, - absl::InlinedVector, 4>* - resource_arg_unique_ids); + absl::InlinedVector* ret_nodes); // Finds the graph's target nodes/function's control ret nodes based on // supplied node names in `control_outputs`. If `control_outputs` are not @@ -2130,7 +2146,6 @@ StatusOr GraphDefImporter::Convert( absl::InlinedVector arg_nodes; absl::InlinedVector ret_nodes; absl::InlinedVector control_ret_nodes; - absl::InlinedVector, 4> resource_arg_unique_ids; llvm::SmallVector attrs; if (specs.graph_as_function) { if (specs.prune_unused_nodes || !specs.inputs.empty() || @@ -2139,10 +2154,9 @@ StatusOr GraphDefImporter::Convert( "Pruning of graph is currently unsupported when the main graph is " "converted to a function."); - TF_ASSIGN_OR_RETURN( - func_type, - importer.GetArgsRetsAndTypesFromFunctionGraph( - context, &arg_nodes, &ret_nodes, &resource_arg_unique_ids)); + TF_ASSIGN_OR_RETURN(func_type, + importer.GetArgsRetsAndTypesFromFunctionGraph( + context, &arg_nodes, &ret_nodes)); TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs, &control_ret_nodes)); @@ -2210,7 +2224,7 @@ StatusOr GraphDefImporter::Convert( TF_RETURN_IF_ERROR(importer.ImporterBase::Convert( func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs, - resource_arg_unique_ids)); + specs.graph_as_function)); return module; } @@ -2327,9 +2341,7 @@ StatusOr GraphDefImporter::InferMainFunctionType( StatusOr GraphDefImporter::GetArgsRetsAndTypesFromFunctionGraph( mlir::MLIRContext* context, absl::InlinedVector* arg_nodes, - absl::InlinedVector* ret_nodes, - absl::InlinedVector, 4>* - resource_arg_unique_ids) { + absl::InlinedVector* ret_nodes) { auto add_node = [](Node* node, absl::InlinedVector* nodes) { auto* attr = node->attrs().Find("index"); if (!attr) @@ -2370,12 +2382,6 @@ GraphDefImporter::GetArgsRetsAndTypesFromFunctionGraph( TF_ASSIGN_OR_RETURN(auto type, InferOutputType(*arg_node.node, /*idx=*/0, builder)); arg_types.push_back(type); - tensorflow::int64 resource_arg_unique_id; - if (TryGetNodeAttr(arg_node.node->attrs(), "_resource_arg_unique_id", - &resource_arg_unique_id)) { - resource_arg_unique_ids->emplace_back(arg_node_and_idx.index(), - resource_arg_unique_id); - } } llvm::SmallVector ret_types; @@ -2428,8 +2434,8 @@ class SavedModelObjectGraphImporter : public ImporterBase { // Main entry point: converts all functions in the given meta graph to an MLIR // Module. static StatusOr Convert( - SavedModelV2Bundle* saved_model, mlir::MLIRContext* context, - absl::Span exported_names, bool add_default_attributes); + SavedModelV2Bundle* saved_model, absl::Span exported_names, + mlir::MLIRContext* context, bool add_default_attributes); private: explicit SavedModelObjectGraphImporter( @@ -3129,8 +3135,8 @@ Status CreateSavedModelIR( } StatusOr SavedModelObjectGraphImporter::Convert( - SavedModelV2Bundle* saved_model, mlir::MLIRContext* context, - absl::Span exported_names, bool add_default_attributes) { + SavedModelV2Bundle* saved_model, absl::Span exported_names, + mlir::MLIRContext* context, bool add_default_attributes) { GraphDebugInfo dummy_debug_info; const GraphDebugInfo& debug_info = saved_model->debug_info() ? 
*saved_model->debug_info() : dummy_debug_info; @@ -3207,17 +3213,20 @@ class SavedModelSignatureDefImporter { public: // Main entry point: converts all functions (specified by SignatureDefs) in // the given meta graph to an MLIR Module. - static StatusOr Convert(const SavedModelBundle& bundle, - mlir::MLIRContext* context) { - SavedModelSignatureDefImporter importer(bundle, context); + static StatusOr Convert( + const SavedModelBundle& bundle, absl::Span exported_names, + mlir::MLIRContext* context) { + SavedModelSignatureDefImporter importer(bundle, exported_names, context); return importer.ConvertSignatures(); } private: SavedModelSignatureDefImporter(const SavedModelBundle& bundle, + absl::Span exported_names, mlir::MLIRContext* context) : bundle_(bundle), + exported_names_(exported_names), module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))) {} // Converts the SavedModel to the SavedModel dialect. Creates an MLIR function @@ -3250,6 +3259,7 @@ class SavedModelSignatureDefImporter { const std::vector>& inputs); const SavedModelBundle& bundle_; + absl::Span exported_names_; mlir::OwningModuleRef module_; }; @@ -3265,6 +3275,9 @@ SavedModelSignatureDefImporter::ConvertSignatures() { GraphDebugInfo debug_info; if (bundle_.debug_info != nullptr) debug_info = *bundle_.debug_info; + llvm::StringSet<> exported_name_set; + exported_name_set.insert(exported_names_.begin(), exported_names_.end()); + for (const auto& key_and_signature_def : signatures) { const std::string& sig_def_key = key_and_signature_def.first; const SignatureDef& signature_def = key_and_signature_def.second; @@ -3274,6 +3287,10 @@ SavedModelSignatureDefImporter::ConvertSignatures() { if (sig_def_key == "__saved_model_init_op") { continue; } + if (!exported_name_set.empty() && + exported_name_set.count(sig_def_key) == 0) { + continue; + } TF_RETURN_IF_ERROR(ConvertSignature(graphdef, sig_def_key, signature_def, debug_info, flib_def)); @@ -3317,6 +3334,9 @@ Status SavedModelSignatureDefImporter::ConvertSignature( graphdef, &sub_graph_def, /*terminal_nodes=*/{specs.outputs.begin(), specs.outputs.end()})); + // Set the function library definitions in the pruned graphdef. + *sub_graph_def.mutable_library() = flib_def.ToProto(); + // Convert sub-graphdef to sub-graph. 
GraphConstructorOptions options; options.allow_internal_ops = true; @@ -3556,12 +3576,14 @@ StatusOr ConvertSavedModelToMlir( SavedModelV2Bundle* saved_model, mlir::MLIRContext* context, absl::Span exported_names, bool add_default_attributes) { return SavedModelObjectGraphImporter::Convert( - saved_model, context, exported_names, add_default_attributes); + saved_model, exported_names, context, add_default_attributes); } StatusOr ConvertSavedModelV1ToMlir( - const SavedModelBundle& saved_model, mlir::MLIRContext* context) { - return SavedModelSignatureDefImporter::Convert(saved_model, context); + const SavedModelBundle& saved_model, absl::Span exported_names, + mlir::MLIRContext* context) { + return SavedModelSignatureDefImporter::Convert(saved_model, exported_names, + context); } std::string MlirModuleToString(mlir::ModuleOp module, bool show_debug_info) { diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h index 8603eadb487..bdb72345201 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h @@ -55,6 +55,7 @@ stream_executor::port::StatusOr ConvertSavedModelToMlir( // expressed with tf_executor dialect. stream_executor::port::StatusOr ConvertSavedModelV1ToMlir(const SavedModelBundle& saved_model, + absl::Span exported_names, mlir::MLIRContext* context); // Serialize a MLIR module to a string. diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc index 2c7f84d8268..6ada0fec4e2 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc @@ -141,7 +141,8 @@ mlir::OwningModuleRef SavedModelObjectGraphToMlirImport( mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport( absl::string_view saved_model_dir, - const std::unordered_set& tags, mlir::MLIRContext* context) { + const std::unordered_set& tags, + absl::Span exported_names, mlir::MLIRContext* context) { tensorflow::SavedModelBundle bundle; tensorflow::SessionOptions session_options; // Force saved model states to be restored to CPU. @@ -155,7 +156,7 @@ mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport( return nullptr; } - auto module_or = ConvertSavedModelV1ToMlir(bundle, context); + auto module_or = ConvertSavedModelV1ToMlir(bundle, exported_names, context); if (!module_or.status().ok()) { LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status(); return nullptr; diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h index f498864c8aa..490b7c7d8f0 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h @@ -64,7 +64,8 @@ mlir::OwningModuleRef SavedModelObjectGraphToMlirImport( // given MLIR `context`. 
mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport( absl::string_view saved_model_dir, - const std::unordered_set& tags, mlir::MLIRContext* context); + const std::unordered_set& tags, + absl::Span exported_names, mlir::MLIRContext* context); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc index 06805e633e2..d7b511094d3 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc @@ -38,6 +38,7 @@ inline static void Log(BridgeLoggerConfig::PrintCallbackFn print_callback, std::unique_ptr os; std::string filepath; if (CreateFileForDumping(name, &os, &filepath).ok()) print_callback(*os); + VLOG(1) << "Dumped MLIR module to " << filepath; } void BridgeLoggerConfig::printBeforeIfEnabled(mlir::Pass* pass, diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index 2374687c920..fd1ba3b1901 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -20,6 +20,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project @@ -247,9 +248,10 @@ Status RefineShapes(llvm::ArrayRef arg_shapes, static void RegisterDialects() { static bool init_once = []() { - mlir::registerDialect(); - mlir::registerDialect(); mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); mlir::registerDialect(); return true; }(); @@ -293,12 +295,22 @@ Status ConvertMLIRToXlaComputation( tf2xla.addPass(mlir::xla_hlo::createLegalizeTfWithTf2XlaPass(device_type)); tf2xla.addNestedPass(mlir::createCanonicalizerPass()); + // Run shape inference pass to propagate shapes through tensor_cast operations + // from static to dynamic shapes. This could be generated if the shape + // inference was originally missing in a TF op but the corresponding HLO op + // had static shape after lowering. + tf2xla.addPass(mlir::TF::CreateTFShapeInferencePass()); + // Run LegalizeTFPass again because the previous legalization passes can // expose more graph pruning and canonicalization opportunities that are // necessary for the second LegalizeTFPass(allow_partial_conversion=false) // invocation. tf2xla.addNestedPass( mlir::xla_hlo::createLegalizeTFPass(false)); + // In order to export to XLA, we must sink constants to control flow regions, + // since XLA uses functional control flow. + tf2xla.addNestedPass( + mlir::xla_hlo::createSinkConstantsToControlFlowPass()); if (VLOG_IS_ON(1)) { // Print the whole module after each pass which requires disabling diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc index fcfef565952..b28f26b6c3c 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc @@ -31,12 +31,14 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/bfloat16/bfloat16.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/tstring.h" #include "tensorflow/stream_executor/lib/statusor.h" @@ -131,13 +133,21 @@ StatusOr ConvertTensor(const Tensor& input_tensor, case DTYPE: \ return ConvertFlatTensor(input_tensor, type); - // TODO(fengliuai): customize the conversions for more types. + // TODO(fengliuai): customize the conversions for quantized and string types. switch (input_dtype) { CONVERT_FLAT(DT_BOOL, bool) CONVERT_FLAT(DT_FLOAT, float) CONVERT_FLAT(DT_DOUBLE, double) + CONVERT_FLAT(DT_INT8, int8) + CONVERT_FLAT(DT_INT16, int16) CONVERT_FLAT(DT_INT32, int32) CONVERT_FLAT(DT_INT64, int64) + CONVERT_FLAT(DT_UINT8, uint8) + CONVERT_FLAT(DT_UINT16, uint16) + CONVERT_FLAT(DT_UINT32, uint32) + CONVERT_FLAT(DT_UINT64, uint64) + CONVERT_FLAT(DT_COMPLEX64, std::complex) + CONVERT_FLAT(DT_COMPLEX128, std::complex) // BFLOAT16 is a special case that it needs to be cast to double type to // match its storage type. @@ -207,12 +217,20 @@ mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type) { // Converts an MLIR dense string elements attribute to a TensorFlow tensor // proto. -Status ConvertStringElementsAttr(const DenseStringElementsAttr attr, - TensorProto* output_tensor) { - for (const auto& val : attr.getRawStringData()) { - output_tensor->add_string_val(val.data(), val.size()); +void ConvertStringElementsAttr( + const DenseStringElementsAttr attr, + protobuf::RepeatedPtrField* output) { + for (const auto& val : attr.getRawStringData()) + output->Add({val.data(), val.size()}); +} + +template +void ConvertComplexElementsAttr(const mlir::DenseElementsAttr attr, + protobuf::RepeatedField* output) { + for (const auto& val : attr.getValues>()) { + output->Add(val.real()); + output->Add(val.imag()); } - return Status::OK(); } // Converts an MLIR opaque elements attribute to a TensorFlow tensor proto. @@ -226,139 +244,80 @@ Status ConvertOpaqueElementsAttr(const ElementsAttr attr, return InvalidArgument("Unexpected elements attribute type from MLIR."); } -// Converts an MLIR elements attribute to a TensorFlow tensor proto -// with the double_val field updated. -Status ConvertDoubleElementsAttr(const ElementsAttr attr, - TensorProto* output_tensor) { - if (auto elts = attr.dyn_cast()) { - if (elts.isSplat()) { - output_tensor->add_double_val(elts.getSplatValue()); - } else { - for (auto value : elts.getValues()) - output_tensor->add_double_val(value); - } - return Status::OK(); - } - return ConvertOpaqueElementsAttr(attr, output_tensor); -} - -// Converts an MLIR elements attribute to a TensorFlow tensor proto -// with the float_val field updated. 
-Status ConvertFloatElementsAttr(const ElementsAttr attr, - TensorProto* output_tensor) { - if (auto elts = attr.dyn_cast()) { - if (elts.isSplat()) { - output_tensor->add_float_val(elts.getSplatValue()); - } else { - for (auto value : elts.getValues()) - output_tensor->add_float_val(value); - } - return Status::OK(); - } - return ConvertOpaqueElementsAttr(attr, output_tensor); -} - -// Converts an MLIR elements attribute to a TensorFlow tensor proto -// with the half_val field updated. -Status ConvertHalfElementsAttr(const ElementsAttr attr, - TensorProto* output_tensor) { - if (auto elts = attr.dyn_cast()) { - if (elts.isSplat()) { - output_tensor->add_half_val( - (*elts.begin()).bitcastToAPInt().getSExtValue()); - } else { - for (const auto& value : elts.getFloatValues()) - output_tensor->add_half_val(value.bitcastToAPInt().getSExtValue()); - } - return Status::OK(); - } - return ConvertOpaqueElementsAttr(attr, output_tensor); -} - -// Converts an MLIR elements attribute to a TensorFlow tensor proto -// with the int_val field updated. -Status ConvertIntElementsAttr(const mlir::ElementsAttr attr, - TensorProto* output_tensor) { - if (auto elts = attr.dyn_cast()) { - if (elts.isSplat()) { - output_tensor->add_int_val((*elts.begin()).getSExtValue()); - } else { - for (const auto& val : elts) - output_tensor->add_int_val(val.getSExtValue()); - } - return Status::OK(); - } - return ConvertOpaqueElementsAttr(attr, output_tensor); -} - -Status ConvertBfloat16ElementsAttr(const mlir::ElementsAttr attr, - TensorProto* output_tensor) { - auto elts = attr.dyn_cast(); - if (!elts) { - return ConvertOpaqueElementsAttr(attr, output_tensor); - } - - // Bfloat16 is internally represented as `double` in MLIR. - if (elts.isSplat()) { - double v = elts.getSplatValue(); - bfloat16 bf16_val = static_cast(v); - output_tensor->add_half_val(absl::bit_cast(bf16_val)); +// Converts an MLIR elements attribute and adds it to specified repeated field. +template +void ConvertElementsAttr(const mlir::DenseElementsAttr attr, + protobuf::RepeatedField* output) { + if (attr.isSplat()) { + output->Add(attr.getSplatValue()); } else { - for (auto v : elts.getValues()) { + for (auto value : attr.getValues()) output->Add(value); + } +} + +// Converts an MLIR elements attribute containing half values and adds it to +// specified repeated field. +void ConvertHalfElementsAttr(const DenseFPElementsAttr attr, + protobuf::RepeatedField* output_tensor) { + if (attr.isSplat()) { + output_tensor->Add((*attr.begin()).bitcastToAPInt().getSExtValue()); + } else { + for (const llvm::APFloat value : attr.getFloatValues()) + output_tensor->Add(value.bitcastToAPInt().getSExtValue()); + } +} + +// Converts an MLIR elements attribute containing int values and adds it to +// specified repeated field. +void ConvertIntElementsAttr(const mlir::DenseIntElementsAttr attr, + protobuf::RepeatedField* output) { + if (attr.isSplat()) { + output->Add((*attr.begin()).getSExtValue()); + } else { + for (const llvm::APInt val : attr) output->Add(val.getSExtValue()); + } +} + +void ConvertBfloat16ElementsAttr(const mlir::DenseFPElementsAttr attr, + protobuf::RepeatedField* output) { + // Bfloat16 is internally represented as `double` in MLIR. 
+ if (attr.isSplat()) { + double v = attr.getSplatValue(); + bfloat16 bf16_val = static_cast(v); + output->Add(absl::bit_cast(bf16_val)); + } else { + for (auto v : attr.getValues()) { bfloat16 bf16_val = static_cast(v); - output_tensor->add_half_val(absl::bit_cast(bf16_val)); + output->Add(absl::bit_cast(bf16_val)); } } - - return Status::OK(); } -// Converts an MLIR elements attribute to a TensorFlow tensor proto -// with the int64_val field updated. -Status ConvertInt64ElementsAttr(const mlir::ElementsAttr attr, - TensorProto* output_tensor) { - if (auto elts = attr.dyn_cast()) { - if (elts.isSplat()) { - output_tensor->add_int64_val((*elts.begin()).getSExtValue()); - } else { - for (const auto& val : elts) - output_tensor->add_int64_val(val.getSExtValue()); - } - return Status::OK(); - } - return ConvertOpaqueElementsAttr(attr, output_tensor); -} - -// Converts an MLIR elements attribute to a TensorFlow tensor proto -// with bool_val field updated. -Status ConvertBoolElementsAttr(const mlir::ElementsAttr attr, - TensorProto* output_tensor) { - if (auto elts = attr.dyn_cast()) { - for (const auto& val : elts) { - output_tensor->add_bool_val(val.getBoolValue()); - } - return Status::OK(); - } - return ConvertOpaqueElementsAttr(attr, output_tensor); -} - -Status ConvertToTensorProto(const ElementsAttr attr, - TensorProto* output_tensor) { +Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) { auto type = attr.getType(); auto shape = type.getShape(); DataType output_dtype; TF_RETURN_IF_ERROR(ConvertToDataType(type, &output_dtype)); - output_tensor->set_dtype(output_dtype); - ConvertToTensorShapeProto(shape, output_tensor->mutable_tensor_shape()); + output->set_dtype(output_dtype); + ConvertToTensorShapeProto(shape, output->mutable_tensor_shape()); + + if (attr.isa()) + return ConvertOpaqueElementsAttr(attr.cast(), output); + + auto dense_attr = attr.dyn_cast(); + if (!dense_attr) return errors::InvalidArgument("Unsupported elements attr"); switch (output_dtype) { case DT_FLOAT: - return ConvertFloatElementsAttr(attr, output_tensor); + ConvertElementsAttr(dense_attr, output->mutable_float_val()); + break; case DT_HALF: - // Handles both DenseFPElementsAttr and OpaqueElementsAttr. 
- return ConvertHalfElementsAttr(attr, output_tensor); + ConvertHalfElementsAttr(dense_attr.cast(), + output->mutable_half_val()); + break; case DT_DOUBLE: - return ConvertDoubleElementsAttr(attr, output_tensor); + ConvertElementsAttr(dense_attr, output->mutable_double_val()); + break; case DT_QUINT8: case DT_UINT8: case DT_INT8: @@ -366,20 +325,40 @@ Status ConvertToTensorProto(const ElementsAttr attr, case DT_UINT16: case DT_INT16: case DT_INT32: - return ConvertIntElementsAttr(attr, output_tensor); + ConvertIntElementsAttr(dense_attr.cast(), + output->mutable_int_val()); + break; + case DT_UINT32: + ConvertElementsAttr(dense_attr, output->mutable_uint32_val()); + break; + case DT_UINT64: + ConvertElementsAttr(dense_attr, output->mutable_uint64_val()); + break; case DT_INT64: - return ConvertInt64ElementsAttr(attr, output_tensor); + ConvertElementsAttr(dense_attr, output->mutable_int64_val()); + break; case DT_BOOL: - return ConvertBoolElementsAttr(attr, output_tensor); + ConvertElementsAttr(dense_attr, output->mutable_bool_val()); + break; case DT_BFLOAT16: - return ConvertBfloat16ElementsAttr(attr, output_tensor); + ConvertBfloat16ElementsAttr(dense_attr.cast(), + output->mutable_half_val()); + break; case DT_STRING: - return ConvertStringElementsAttr(attr.cast(), - output_tensor); + ConvertStringElementsAttr(dense_attr.cast(), + output->mutable_string_val()); + break; + case DT_COMPLEX64: + ConvertComplexElementsAttr(dense_attr, output->mutable_scomplex_val()); + break; + case DT_COMPLEX128: + ConvertComplexElementsAttr(dense_attr, output->mutable_dcomplex_val()); + break; default: - return ConvertOpaqueElementsAttr(attr.cast(), - output_tensor); + return errors::Unimplemented(absl::StrCat("Unimplemented data type ", + DataTypeString(output_dtype))); } + return Status::OK(); } Status ConvertToTensor(const mlir::ElementsAttr attr, Tensor* output_tensor) { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc index d711c19baae..bf96e3d1df4 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include +#include #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project @@ -99,48 +100,74 @@ TEST(ConvertTypeToTensorTypeTest, ConvertStringTensor) { EXPECT_EQ(string_values[3], mlir::StringRef("four")); } -TEST(ConvertTypeToTensorTypeTest, Convert16BitFloats) { +class ConvertTensorTest : public ::testing::Test { + protected: + template + void VerifyConversion(std::initializer_list values, DataType dtype, + mlir::Type expected_ty) { + mlir::Builder b(expected_ty.getContext()); + Tensor tensor(dtype, TensorShape({static_cast(values.size())})); + tensor.flat().setValues(values); + + auto value_or = ConvertTensor(tensor, &b); + TF_ASSERT_OK(value_or.status()); + auto attr = value_or.ValueOrDie(); + + EXPECT_EQ(attr.getType().getElementType(), expected_ty); + + Tensor out; + TF_ASSERT_OK(ConvertToTensor(attr, &out)); + + test::ExpectTensorEqual(tensor, out); + } +}; + +TEST_F(ConvertTensorTest, Simple) { RegisterDialects(); + mlir::MLIRContext context; - mlir::Builder b(&context); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {Eigen::half(1.0)}, DT_HALF, mlir::FloatType::getF16(&context))); + ASSERT_NO_FATAL_FAILURE( + VerifyConversion({bfloat16(1.0), bfloat16(-1.0)}, DT_BFLOAT16, + mlir::FloatType::getBF16(&context))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1.0, -1.0}, DT_FLOAT, mlir::FloatType::getF32(&context))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1.0, -1.0}, DT_DOUBLE, mlir::FloatType::getF64(&context))); - { - // Create the sample tensor to convert. - Tensor tensor(DT_HALF, TensorShape({1})); - auto Tt = tensor.flat(); - Tt.setValues({Eigen::half(1.0)}); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, -1}, DT_INT8, mlir::IntegerType::get(8, &context))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, -1}, DT_INT16, mlir::IntegerType::get(16, &context))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, -1}, DT_INT32, mlir::IntegerType::get(32, &context))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, -1}, DT_INT64, mlir::IntegerType::get(64, &context))); - auto value_or = ConvertTensor(tensor, &b); - TF_EXPECT_OK(value_or.status()); - auto attr = value_or.ValueOrDie(); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, 2}, DT_UINT8, + mlir::IntegerType::get( + 8, mlir::IntegerType::SignednessSemantics::Unsigned, &context))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, 2}, DT_UINT16, + mlir::IntegerType::get( + 16, mlir::IntegerType::SignednessSemantics::Unsigned, &context))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, 2}, DT_UINT32, + mlir::IntegerType::get( + 32, mlir::IntegerType::SignednessSemantics::Unsigned, &context))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, 2}, DT_UINT64, + mlir::IntegerType::get( + 64, mlir::IntegerType::SignednessSemantics::Unsigned, &context))); - EXPECT_TRUE(attr.isa()); - EXPECT_TRUE(attr.getType().getElementType().isF16()); - - Tensor out; - TF_ASSERT_OK(ConvertToTensor(attr, &out)); - - test::ExpectTensorEqual(tensor, out); - } - - { - // Create the sample tensor to convert. 
- Tensor tensor(DT_BFLOAT16, TensorShape({2})); - auto Tt = tensor.flat(); - Tt.setValues({bfloat16(1.0), bfloat16(-1.0)}); - - auto value_or = ConvertTensor(tensor, &b); - TF_EXPECT_OK(value_or.status()); - auto attr = value_or.ValueOrDie(); - - EXPECT_TRUE(attr.isa()); - EXPECT_TRUE(attr.getType().getElementType().isBF16()); - - Tensor out; - TF_ASSERT_OK(ConvertToTensor(attr, &out)); - - test::ExpectTensorEqual(tensor, out); - } + ASSERT_NO_FATAL_FAILURE(VerifyConversion>( + {{0.0, 1.0}, {1.0, 0.0}}, DT_COMPLEX64, + mlir::ComplexType::get(mlir::FloatType::getF32(&context)))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion>( + {{0.0, 1.0}, {1.0, 0.0}}, DT_COMPLEX128, + mlir::ComplexType::get(mlir::FloatType::getF64(&context)))); } } // namespace diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc index cc795259893..4877cbc4a44 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc @@ -59,6 +59,18 @@ limitations under the License. namespace tensorflow { namespace { +// static TensorFlow op prefix set. +std::set* GlobalOpPrefixes() { + static std::set* global_op_prefixes = [] { + std::set* result = new std::set; + result->insert("tf."); + result->insert("_tf."); + result->insert("tf_executor."); + return result; + }(); + return global_op_prefixes; +} + // Converts a location to the debug information for the node def. Status ConvertLocation(mlir::Location inst_loc, NodeDef::ExperimentalDebugInfo* debug_info) { @@ -268,8 +280,10 @@ StatusOr GetTensorFlowOpName(llvm::StringRef op_name) { // - ".sink" or ".Sink": only the NextIteration operation has this suffix. We // don't need to consider ".source"/".Source" because the nodes with this // suffix are skipped by the caller and will not be added to the graph. - if (!op_name.consume_front("_tf.") && !op_name.consume_front("tf.") && - !op_name.consume_front("tf_executor.")) { + auto prefixes = GlobalOpPrefixes(); + if (std::none_of(prefixes->begin(), prefixes->end(), [&](std::string prefix) { + return op_name.consume_front(prefix); + })) { return errors::FailedPrecondition("op node '", op_name.str(), "' was not a TF op!"); } @@ -506,4 +520,9 @@ bool IsLegacyCallInstruction(mlir::Operation* inst) { inst->getName().getStringRef().compare("_tf.LegacyCall") == 0; } +Status AddTensorFlowOpPrefix(std::string prefix) { + GlobalOpPrefixes()->insert(prefix); + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h index 32ed528bd0d..58fe39fa4e8 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h @@ -34,10 +34,17 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/stream_executor/lib/statusor.h" +namespace mlir { +class ShapedType; +} // namespace mlir + namespace tensorflow { using stream_executor::port::StatusOr; +// Add custom op prefix for TensorFlow dialects. +Status AddTensorFlowOpPrefix(std::string); + // Maps an MLIR op name in the TensorFlow dialect or the TensorFlow control // dialect back into a TensorFlow valid op name. 
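The export_utils change above replaces the hard-coded "tf."/"_tf."/"tf_executor." prefix checks with a registrable set, and the AddTensorFlowOpPrefix declared here is how other exporters opt in (the TFJS translator later in this patch registers "tfjs."). A standalone sketch of the same check, using std::string in place of llvm::StringRef::consume_front:

```cpp
#include <algorithm>
#include <iostream>
#include <set>
#include <string>

// Registrable prefix set, mirroring GlobalOpPrefixes() above.
std::set<std::string>& OpPrefixes() {
  static std::set<std::string> prefixes = {"tf.", "_tf.", "tf_executor."};
  return prefixes;
}

// Strips the first registered prefix that matches; returns false if none do
// (the real code then reports "was not a TF op!").
bool StripOpPrefix(std::string* op_name) {
  const auto& prefixes = OpPrefixes();
  return std::any_of(prefixes.begin(), prefixes.end(),
                     [&](const std::string& prefix) {
                       if (op_name->rfind(prefix, 0) != 0) return false;
                       op_name->erase(0, prefix.size());
                       return true;
                     });
}

int main() {
  OpPrefixes().insert("tfjs.");  // what AddTensorFlowOpPrefix("tfjs.") enables
  std::string name = "tfjs.Prelu";
  std::cout << StripOpPrefix(&name) << " " << name << "\n";  // prints: 1 Prelu
  return 0;
}
```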
StatusOr GetTensorFlowOpName(llvm::StringRef); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc index 6cf2781e48d..282b7ad3139 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc @@ -26,9 +26,9 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/FormatVariadic.h" +#include "mlir/IR/Attributes.h" // from @llvm-project #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -39,6 +39,12 @@ limitations under the License. #include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { + +const char* const kTPUReplicatedHost = "TPU_REPLICATED_HOST"; +const char* const kNumCoresPerReplicaAttr = "num_cores_per_replica"; +const char* const kTopologyAttr = "topology"; +const char* const kDeviceAssignmentAttr = "device_assignment"; + // Device coordinates are defined as (x, y, z, core), thus resulting in a rank 4 // topology. constexpr int kTPUTopologyRank = 4; @@ -46,8 +52,8 @@ constexpr int kTPUTopologyRank = 4; constexpr char kDeviceTPUSystem[] = "TPU_SYSTEM"; constexpr char kDeviceTPU[] = "TPU"; constexpr char kTPUReplicatedCore[] = "TPU_REPLICATED_CORE"; -constexpr char kTopologyAttr[] = "topology"; -constexpr char kDeviceAssignmentAttr[] = "device_assignment"; +constexpr char kBadIntArrayElementMsg[] = + "bad '{0}' attribute at index {1}, not an int"; using Device = DeviceNameUtils::ParsedName; using Devices = llvm::ArrayRef; @@ -164,12 +170,19 @@ std::string GetTPUCompilationDevice(Device system_device) { return DeviceNameUtils::ParsedNameToString(system_device); } +// Finds the host CPU device for a given TPU device. +std::string GetCPUHostDeviceForTPUDevice(Device tpu_device) { + tpu_device.type = DEVICE_CPU; + tpu_device.id = 0; + return DeviceNameUtils::ParsedNameToString(tpu_device); +} + // Determines execution devices when topology and device assignment are not // defined. This is a special case where a single core computation is replicated // to every core in the mesh. TPU devices are simply added to // `execution_devices` of one replica. `num_replicas` must be 1 or the total // number of TPU devices available, and `num_cores_per_replica` must be 1. 
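The full-mesh path documented above assigns replica i the TPU at task i / num_tpus_per_task, device i % num_tpus_per_task, and this change additionally records that task's CPU:0 as the associated host device for outside compilation. A standalone sketch of the pairing with hand-formatted device strings (the real code builds names through DeviceNameUtils::ParsedNameToString):

```cpp
#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  const int num_tasks = 2;
  const int num_tpus_per_task = 4;
  const int num_replicas = num_tasks * num_tpus_per_task;  // one core per replica

  // Each replica gets one TPU device plus the CPU:0 host of the same task.
  std::vector<std::pair<std::string, std::string>> devices_and_hosts;
  devices_and_hosts.reserve(num_replicas);
  for (int i = 0; i < num_replicas; ++i) {
    const int task = i / num_tpus_per_task;
    const int device = i % num_tpus_per_task;
    const std::string prefix =
        "/job:worker/replica:0/task:" + std::to_string(task);
    devices_and_hosts.emplace_back(
        prefix + "/device:TPU:" + std::to_string(device),
        prefix + "/device:CPU:0");
  }

  for (const auto& dh : devices_and_hosts)
    std::cout << dh.first << " -> " << dh.second << "\n";
  return 0;
}
```

With 2 tasks of 4 TPUs each, this reproduces the device/host pairs asserted in the ValidFullMeshDeviceAssignment test later in this patch.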
-StatusOr GetFullMeshTPUExecutionDeviceAssignment( +StatusOr GetFullMeshTPUExecutionDeviceAssignment( int num_replicas, int num_cores_per_replica, llvm::ArrayRef> tpu_devices) { const int num_tasks = tpu_devices.size(); @@ -185,17 +198,18 @@ StatusOr GetFullMeshTPUExecutionDeviceAssignment( "'num_cores_per_replica' must be equal to 1, got ", num_cores_per_replica); - ExecutionDevices execution_devices; - execution_devices.reserve(num_replicas); + TPUDevicesAndHosts devices_and_hosts; + devices_and_hosts.reserve(num_replicas); for (int i = 0; i < num_replicas; ++i) { const int task = i / num_tpus_per_task; const int device = i % num_tpus_per_task; - execution_devices.push_back( - {tensorflow::DeviceNameUtils::ParsedNameToString( - tpu_devices[task][device])}); + const auto& tpu_device = tpu_devices[task][device]; + devices_and_hosts.push_back({TPUDeviceAndHost( + /*device=*/tensorflow::DeviceNameUtils::ParsedNameToString(tpu_device), + /*host=*/GetCPUHostDeviceForTPUDevice(tpu_device))}); } - return execution_devices; + return devices_and_hosts; } // Helper struct for keeping track of task and device for an associated TPU @@ -326,7 +340,7 @@ StatusOr> ParseTopologyAttr( // - number of device coordinates (in tuple 3) match number 'num_replicas' * // 'num_cores_per_replica' // - a TPU device associated with each device coordinate -StatusOr> +StatusOr> GetGeneralTPUExecutionDeviceAssignment( int num_replicas, int num_cores_per_replica, llvm::ArrayRef> tpu_devices, @@ -361,9 +375,9 @@ GetGeneralTPUExecutionDeviceAssignment( std::vector used_device_ids( location_to_id(bound_x - 1, bound_y - 1, bound_z - 1, bound_core - 1), false); - ExecutionDevices execution_devices( - num_replicas, - llvm::SmallVector(num_cores_per_replica, "")); + TPUDevicesAndHosts devices_and_hosts( + num_replicas, llvm::SmallVector( + num_cores_per_replica, TPUDeviceAndHost())); xla::DeviceAssignment device_assignment(num_replicas, num_cores_per_replica); int pos = 0; for (int replica = 0; replica < num_replicas; ++replica) { @@ -393,20 +407,43 @@ GetGeneralTPUExecutionDeviceAssignment( used_device_ids[device_id] = true; device_assignment(replica, logical_core) = device_id; - execution_devices[replica][logical_core] = - DeviceNameUtils::ParsedNameToString(tpu_devices[task][device]); + auto& device_and_host = devices_and_hosts[replica][logical_core]; + const auto& tpu_device = tpu_devices[task][device]; + device_and_host.device = DeviceNameUtils::ParsedNameToString(tpu_device); + device_and_host.host = GetCPUHostDeviceForTPUDevice(tpu_device); } } xla::DeviceAssignmentProto device_assignment_proto; TF_RETURN_IF_ERROR(device_assignment.Serialize(&device_assignment_proto)); - return std::pair( - std::move(execution_devices), std::move(device_assignment_proto)); + return std::pair( + std::move(devices_and_hosts), std::move(device_assignment_proto)); } } // anonymous namespace +StatusOr> GetDeviceCoordinates( + mlir::ArrayAttr device_assignment_attr) { + llvm::SmallVector device_coordinates; + device_coordinates.reserve(device_assignment_attr.size()); + + for (auto device_coordinate_and_idx : + llvm::enumerate(device_assignment_attr)) { + auto device_coordinate = + device_coordinate_and_idx.value().dyn_cast(); + if (!device_coordinate) + return errors::InvalidArgument( + llvm::formatv(kBadIntArrayElementMsg, kDeviceAssignmentAttr, + device_coordinate_and_idx.index()) + .str()); + + device_coordinates.push_back(device_coordinate.getInt()); + } + + return device_coordinates; +} + StatusOr GetTPUCompilationAndExecutionDevices( 
Devices devices, int num_replicas, int num_cores_per_replica, llvm::StringRef topology_attr, diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h index dd296a13f4b..6bb541ab683 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h @@ -22,6 +22,7 @@ limitations under the License. #include "llvm/ADT/Optional.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "mlir/IR/Attributes.h" // from @llvm-project #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/util/device_name_utils.h" @@ -30,32 +31,52 @@ limitations under the License. namespace tensorflow { using stream_executor::port::StatusOr; -// TPU devices to be used for execution (e.g. devices for TPUExecute ops). They -// are ordered by `num_replicas` followed by `num_cores_per_replica`. -using ExecutionDevices = - llvm::SmallVector, 8>; +extern const char* const kTPUReplicatedHost; +extern const char* const kNumCoresPerReplicaAttr; +extern const char* const kTopologyAttr; +extern const char* const kDeviceAssignmentAttr; -// TPU compilation device, execution devices, and optionally execution device -// IDs. Execution device IDs are populated if `topology` and `device_assignment` -// are provided. +// A TPU device for execution alongside its associated host CPU device. +struct TPUDeviceAndHost { + TPUDeviceAndHost() {} + TPUDeviceAndHost(llvm::StringRef device, llvm::StringRef host) + : device(device), host(host) {} + + std::string device; + std::string host; +}; + +// TPU devices to be used for execution (e.g. devices for TPUExecute ops) and +// their associated host CPU devices (for outside compilation). They are ordered +// by `num_replicas` followed by `num_cores_per_replica`. +using TPUDevicesAndHosts = + llvm::SmallVector, 8>; + +// TPU compilation device, execution and associated host devices, and optionally +// execution device IDs. Execution device IDs are populated if `topology` and +// `device_assignment` are provided. struct TPUDeviceAssignment { TPUDeviceAssignment(llvm::StringRef compilation_device, - ExecutionDevices&& execution_devices) + TPUDevicesAndHosts&& tpu_devices) : compilation_device(compilation_device), - execution_devices(std::move(execution_devices)) {} + tpu_devices(std::move(tpu_devices)) {} TPUDeviceAssignment(llvm::StringRef compilation_device, - ExecutionDevices&& execution_devices, + TPUDevicesAndHosts&& tpu_devices, xla::DeviceAssignmentProto&& xla_device_assignment) : compilation_device(compilation_device), - execution_devices(std::move(execution_devices)), + tpu_devices(std::move(tpu_devices)), xla_device_assignment(std::move(xla_device_assignment)) {} std::string compilation_device; - ExecutionDevices execution_devices; + TPUDevicesAndHosts tpu_devices; llvm::Optional xla_device_assignment; }; +// Extracts device coordinates from a device assignment attribute on an op. +StatusOr> GetDeviceCoordinates( + mlir::ArrayAttr device_assignment_attr); + // Finds the TPU compilation device and execution devices from `devices` for a // TPU computation subgraph. 
Compilation device is determined from looking up // all TPU_SYSTEM:0 devices and choosing the CPU device associated to the first diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc index 87319f2adeb..a70e93a0195 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc @@ -19,6 +19,8 @@ limitations under the License. #include #include "llvm/Support/FormatVariadic.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/tpu/topology.pb.h" @@ -323,30 +325,46 @@ TEST(TPURewriteDeviceUtilTest, ValidFullMeshDeviceAssignment) { TF_ASSERT_OK(status_or.status()); - auto& tpu_device_assignment = status_or.ValueOrDie(); + const auto& tpu_device_assignment = status_or.ValueOrDie(); EXPECT_EQ(tpu_device_assignment.compilation_device, "/job:worker/replica:0/task:0/device:CPU:0"); - auto& execution_devices = tpu_device_assignment.execution_devices; - ASSERT_EQ(execution_devices.size(), 8); - for (const auto& replica_execution_device : execution_devices) - ASSERT_EQ(replica_execution_device.size(), 1); + const auto& tpu_devices = tpu_device_assignment.tpu_devices; + ASSERT_EQ(tpu_devices.size(), 8); + for (const auto& replica_tpu_devices : tpu_devices) + ASSERT_EQ(replica_tpu_devices.size(), 1); - EXPECT_EQ(execution_devices[0][0], + EXPECT_EQ(tpu_devices[0][0].device, "/job:worker/replica:0/task:0/device:TPU:0"); - EXPECT_EQ(execution_devices[1][0], + EXPECT_EQ(tpu_devices[0][0].host, + "/job:worker/replica:0/task:0/device:CPU:0"); + EXPECT_EQ(tpu_devices[1][0].device, "/job:worker/replica:0/task:0/device:TPU:1"); - EXPECT_EQ(execution_devices[2][0], + EXPECT_EQ(tpu_devices[1][0].host, + "/job:worker/replica:0/task:0/device:CPU:0"); + EXPECT_EQ(tpu_devices[2][0].device, "/job:worker/replica:0/task:0/device:TPU:2"); - EXPECT_EQ(execution_devices[3][0], + EXPECT_EQ(tpu_devices[2][0].host, + "/job:worker/replica:0/task:0/device:CPU:0"); + EXPECT_EQ(tpu_devices[3][0].device, "/job:worker/replica:0/task:0/device:TPU:3"); - EXPECT_EQ(execution_devices[4][0], + EXPECT_EQ(tpu_devices[3][0].host, + "/job:worker/replica:0/task:0/device:CPU:0"); + EXPECT_EQ(tpu_devices[4][0].device, "/job:worker/replica:0/task:1/device:TPU:0"); - EXPECT_EQ(execution_devices[5][0], + EXPECT_EQ(tpu_devices[4][0].host, + "/job:worker/replica:0/task:1/device:CPU:0"); + EXPECT_EQ(tpu_devices[5][0].device, "/job:worker/replica:0/task:1/device:TPU:1"); - EXPECT_EQ(execution_devices[6][0], + EXPECT_EQ(tpu_devices[5][0].host, + "/job:worker/replica:0/task:1/device:CPU:0"); + EXPECT_EQ(tpu_devices[6][0].device, "/job:worker/replica:0/task:1/device:TPU:2"); - EXPECT_EQ(execution_devices[7][0], + EXPECT_EQ(tpu_devices[6][0].host, + "/job:worker/replica:0/task:1/device:CPU:0"); + EXPECT_EQ(tpu_devices[7][0].device, "/job:worker/replica:0/task:1/device:TPU:3"); + EXPECT_EQ(tpu_devices[7][0].host, + "/job:worker/replica:0/task:1/device:CPU:0"); EXPECT_FALSE(tpu_device_assignment.xla_device_assignment.hasValue()); } @@ -410,30 +428,46 @@ TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh2x2x2) { TF_ASSERT_OK(status_or.status()); - auto& tpu_device_assignment = status_or.ValueOrDie(); + const auto& tpu_device_assignment = 
status_or.ValueOrDie(); EXPECT_EQ(tpu_device_assignment.compilation_device, "/job:worker/replica:0/task:0/device:CPU:0"); - auto& execution_devices = tpu_device_assignment.execution_devices; - ASSERT_EQ(execution_devices.size(), 4); - for (const auto& replica_execution_device : execution_devices) - ASSERT_EQ(replica_execution_device.size(), 2); + const auto& tpu_devices = tpu_device_assignment.tpu_devices; + ASSERT_EQ(tpu_devices.size(), 4); + for (const auto& replica_tpu_devices : tpu_devices) + ASSERT_EQ(replica_tpu_devices.size(), 2); - EXPECT_EQ(execution_devices[0][0], + EXPECT_EQ(tpu_devices[0][0].device, "/job:worker/replica:0/task:0/device:TPU:0"); - EXPECT_EQ(execution_devices[0][1], + EXPECT_EQ(tpu_devices[0][0].host, + "/job:worker/replica:0/task:0/device:CPU:0"); + EXPECT_EQ(tpu_devices[0][1].device, "/job:worker/replica:0/task:1/device:TPU:3"); - EXPECT_EQ(execution_devices[1][0], + EXPECT_EQ(tpu_devices[0][1].host, + "/job:worker/replica:0/task:1/device:CPU:0"); + EXPECT_EQ(tpu_devices[1][0].device, "/job:worker/replica:0/task:0/device:TPU:1"); - EXPECT_EQ(execution_devices[1][1], + EXPECT_EQ(tpu_devices[1][0].host, + "/job:worker/replica:0/task:0/device:CPU:0"); + EXPECT_EQ(tpu_devices[1][1].device, "/job:worker/replica:0/task:1/device:TPU:2"); - EXPECT_EQ(execution_devices[2][0], + EXPECT_EQ(tpu_devices[1][1].host, + "/job:worker/replica:0/task:1/device:CPU:0"); + EXPECT_EQ(tpu_devices[2][0].device, "/job:worker/replica:0/task:0/device:TPU:3"); - EXPECT_EQ(execution_devices[2][1], + EXPECT_EQ(tpu_devices[2][0].host, + "/job:worker/replica:0/task:0/device:CPU:0"); + EXPECT_EQ(tpu_devices[2][1].device, "/job:worker/replica:0/task:1/device:TPU:0"); - EXPECT_EQ(execution_devices[3][0], + EXPECT_EQ(tpu_devices[2][1].host, + "/job:worker/replica:0/task:1/device:CPU:0"); + EXPECT_EQ(tpu_devices[3][0].device, "/job:worker/replica:0/task:0/device:TPU:2"); - EXPECT_EQ(execution_devices[3][1], + EXPECT_EQ(tpu_devices[3][0].host, + "/job:worker/replica:0/task:0/device:CPU:0"); + EXPECT_EQ(tpu_devices[3][1].device, "/job:worker/replica:0/task:1/device:TPU:1"); + EXPECT_EQ(tpu_devices[3][1].host, + "/job:worker/replica:0/task:1/device:CPU:0"); auto& xla_device_assignment = tpu_device_assignment.xla_device_assignment; ASSERT_TRUE(xla_device_assignment.hasValue()); @@ -511,23 +545,35 @@ TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh1x2x1x3) { EXPECT_EQ(tpu_device_assignment.compilation_device, "/job:worker/replica:0/task:0/device:CPU:0"); - auto& execution_devices = tpu_device_assignment.execution_devices; - ASSERT_EQ(execution_devices.size(), 2); - for (const auto& replica_execution_device : execution_devices) - ASSERT_EQ(replica_execution_device.size(), 3); + auto& tpu_devices = tpu_device_assignment.tpu_devices; + ASSERT_EQ(tpu_devices.size(), 2); + for (const auto& replica_tpu_devices : tpu_devices) + ASSERT_EQ(replica_tpu_devices.size(), 3); - EXPECT_EQ(execution_devices[0][0], + EXPECT_EQ(tpu_devices[0][0].device, "/job:worker/replica:0/task:1/device:TPU:1"); - EXPECT_EQ(execution_devices[0][1], + EXPECT_EQ(tpu_devices[0][0].host, + "/job:worker/replica:0/task:1/device:CPU:0"); + EXPECT_EQ(tpu_devices[0][1].device, "/job:worker/replica:0/task:1/device:TPU:0"); - EXPECT_EQ(execution_devices[0][2], + EXPECT_EQ(tpu_devices[0][1].host, + "/job:worker/replica:0/task:1/device:CPU:0"); + EXPECT_EQ(tpu_devices[0][2].device, "/job:worker/replica:0/task:2/device:TPU:0"); - EXPECT_EQ(execution_devices[1][0], + EXPECT_EQ(tpu_devices[0][2].host, + 
"/job:worker/replica:0/task:2/device:CPU:0"); + EXPECT_EQ(tpu_devices[1][0].device, "/job:worker/replica:0/task:2/device:TPU:1"); - EXPECT_EQ(execution_devices[1][1], + EXPECT_EQ(tpu_devices[1][0].host, + "/job:worker/replica:0/task:2/device:CPU:0"); + EXPECT_EQ(tpu_devices[1][1].device, "/job:worker/replica:0/task:0/device:TPU:0"); - EXPECT_EQ(execution_devices[1][2], + EXPECT_EQ(tpu_devices[1][1].host, + "/job:worker/replica:0/task:0/device:CPU:0"); + EXPECT_EQ(tpu_devices[1][2].device, "/job:worker/replica:0/task:0/device:TPU:1"); + EXPECT_EQ(tpu_devices[1][2].host, + "/job:worker/replica:0/task:0/device:CPU:0"); auto& xla_device_assignment = tpu_device_assignment.xla_device_assignment; ASSERT_TRUE(xla_device_assignment.hasValue()); @@ -552,5 +598,29 @@ TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh1x2x1x3) { EXPECT_EQ(computation_device_2.replica_device_ids(1), 3); } +TEST(TPURewriteDeviceUtilTest, TestGetDeviceCoordinates) { + mlir::MLIRContext context; + mlir::Builder builder(&context); + auto device_assignment_attr = builder.getI64ArrayAttr({1, 2, 3}); + auto status_or_device_coodinates = + GetDeviceCoordinates(device_assignment_attr); + ASSERT_TRUE(status_or_device_coodinates.ok()); + auto device_coordinates = status_or_device_coodinates.ConsumeValueOrDie(); + EXPECT_EQ(device_coordinates[0], 1); + EXPECT_EQ(device_coordinates[1], 2); + EXPECT_EQ(device_coordinates[2], 3); +} + +TEST(TPURewriteDeviceUtilTest, TestInvalidAttrForDeviceAssignmentDisallowed) { + mlir::MLIRContext context; + mlir::Builder builder(&context); + auto device_assignment_attr = builder.getF32ArrayAttr({1.0, 2.0, 3.0}); + auto status_or_device_coodinates = + GetDeviceCoordinates(device_assignment_attr); + ASSERT_TRUE(!status_or_device_coodinates.ok()); + EXPECT_EQ(status_or_device_coodinates.status().error_message(), + "bad 'device_assignment' attribute at index 0, not an int"); +} + } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc index 62b862f5e21..2e1528e0d60 100644 --- a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc @@ -104,26 +104,24 @@ int main(int argc, char** argv) { return 1; } + std::unordered_set tags = absl::StrSplit(saved_model_tags, ','); + std::vector exported_names_vector = + absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty()); + absl::Span exported_names(exported_names_vector); + if (import_saved_model_object_graph) { - std::unordered_set tags = - absl::StrSplit(saved_model_tags, ','); - std::vector exported_names = - absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty()); mlir::MLIRContext context; auto module = tensorflow::SavedModelObjectGraphToMlirImport( - input_filename, tags, absl::Span(exported_names), - &context); + input_filename, tags, exported_names, &context); if (!module) return 1; module->print(output->os()); } else if (import_saved_model_signature_defs) { - std::unordered_set tags = - absl::StrSplit(saved_model_tags, ','); mlir::MLIRContext context; auto module = tensorflow::SavedModelSignatureDefsToMlirImport( - input_filename, tags, &context); + input_filename, tags, exported_names, &context); if (!module) return 1; module->print(output->os()); diff --git a/tensorflow/compiler/mlir/tfjs/BUILD b/tensorflow/compiler/mlir/tfjs/BUILD index 9b731d2c912..ac629ac4573 100644 --- a/tensorflow/compiler/mlir/tfjs/BUILD +++ b/tensorflow/compiler/mlir/tfjs/BUILD @@ 
-1,4 +1,5 @@ load("//third_party/mlir:tblgen.bzl", "gentbl") +load("//tensorflow:tensorflow.bzl", "tf_cc_binary") package( default_visibility = ["//visibility:public"], @@ -39,7 +40,7 @@ gentbl( "ir/tfjs_ops.td", "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td", - "@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td", + "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", ], ) @@ -131,10 +132,106 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass", - "//tensorflow/compiler/mlir/tensorflow:translate_lib", - "@llvm-project//mlir:Analysis", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Transforms", ], ) + +cc_library( + name = "json_translate_lib", + srcs = [ + "translate/json_translate.cc", + ], + hdrs = [ + "translate/json_translate.h", + ], + deps = [ + ":tensorflow_js", + ":tensorflow_js_dialect_registration", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", + "//tensorflow/compiler/mlir/tensorflow:export_utils", + "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Translation", + ], + alwayslink = 1, +) + +cc_library( + name = "tf_to_tfjs_json", + srcs = ["translate/tf_to_tfjs_json.cc"], + hdrs = [ + "translate/tf_to_tfjs_json.h", + ], + deps = [ + ":json_translate_lib", + ":tfjs_optimize", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:decode_constant_pass", + "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib", + "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", + "//tensorflow/compiler/mlir/tensorflow:translate_cl_options", + "//tensorflow/compiler/mlir/tensorflow:translate_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:support", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], + alwayslink = 1, +) + +tf_cc_binary( + name = "json_translate", + deps = [ + ":json_translate_lib", + "@llvm-project//mlir:MlirTranslateMain", + ], +) + +filegroup( + name = "tf_tfjs_translate_main", + srcs = [ + "translate/tf_tfjs_translate.cc", + ], +) + +tf_cc_binary( + name = "tf_tfjs_translate", + srcs = [":tf_tfjs_translate_main"], + deps = [ + ":json_translate_lib", + ":tensorflow_js_passes", + ":tf_to_tfjs_json", + ":tfjs_optimize", + "//tensorflow/compiler/mlir:init_mlir", + "//tensorflow/compiler/mlir/tensorflow:translate_cl_options", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:errors", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + 
"@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], +) diff --git a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h index 318895de79c..9c98c9b0e19 100644 --- a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h +++ b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h @@ -26,8 +26,9 @@ limitations under the License. #include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/OpImplementation.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/Interfaces/SideEffects.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project + namespace mlir { namespace tfjs { diff --git a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.td b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.td index 172347bc0f5..134aa010d8c 100644 --- a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.td +++ b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.td @@ -23,7 +23,7 @@ limitations under the License. #define TFJS_DIALECT include "mlir/IR/OpBase.td" -include "mlir/Interfaces/SideEffects.td" +include "mlir/Interfaces/SideEffectInterfaces.td" //===----------------------------------------------------------------------===// // TensorFlow.js dialect definitions diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/BUILD b/tensorflow/compiler/mlir/tfjs/tests/e2e/BUILD similarity index 65% rename from tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/BUILD rename to tensorflow/compiler/mlir/tfjs/tests/e2e/BUILD index 4faa8d2efe8..5c8d37da2f0 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/BUILD +++ b/tensorflow/compiler/mlir/tfjs/tests/e2e/BUILD @@ -1,11 +1,15 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") -package(licenses = ["notice"]) +licenses(["notice"]) glob_lit_tests( - data = [":test_utilities"], + data = [ + ":test_utilities", + ], driver = "@llvm-project//mlir:run_lit.sh", - test_file_exts = ["mlir"], + test_file_exts = [ + "pbtxt", + ], ) # Bundle together all of the test utilities that are used by tests. 
@@ -13,7 +17,7 @@ filegroup( name = "test_utilities", testonly = True, data = [ - "//tensorflow/compiler/mlir:tf-opt", + "//tensorflow/compiler/mlir/tfjs:tf_tfjs_translate", "@llvm-project//llvm:FileCheck", ], ) diff --git a/tensorflow/compiler/mlir/tfjs/tests/e2e/add.pbtxt b/tensorflow/compiler/mlir/tfjs/tests/e2e/add.pbtxt new file mode 100644 index 00000000000..f6a324fdc13 --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/tests/e2e/add.pbtxt @@ -0,0 +1,78 @@ +# RUN: tf_tfjs_translate %s -tf-input-arrays=input0,input1 -tf-input-data-types=DT_INT32,DT_INT32 -tf-input-shapes=10:10 -tf-output-arrays=Mul -o - | FileCheck %s --dump-input-on-failure +# Add two tensor<4xi32> inputs and return the result + +node { + name: "Add" + op: "Add" + input: "input0" + input: "input1" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "input0" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } +} +node { + name: "input1" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } +} +node { + name: "Mul" + op: "Mul" + input: "Add" + input: "Add" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +versions { + producer: 27 +} + +# CHECK: "name": "input0" +# CHECK-NEXT: "op": "Placeholder" +# CHECK: "type": "DT_INT32" +# CHECK: "name": "input1", +# CHECK-NEXT: "op": "Placeholder" +# CHECK: "type": "DT_INT32" +# CHECK: "name": "Add" +# CHECK-NEXT: "op": "AddV2" +# CHECK-NEXT: "input": +# CHECK-NEXT: "input0" +# CHECK-NEXT: "input1" +# CHECK: "type": "DT_INT32" +# CHECK: "name": "Mul1" +# CHECK-NEXT: "op": "Mul" +# CHECK-NEXT: "input": +# CHECK-NEXT: "Add" +# CHECK-NEXT: "Add" +# CHECK: "type": "DT_INT32" +# CHECK: "name": "Mul" +# CHECK-NEXT: "op": "_Retval" +# CHECK-NEXT: "input": +# CHECK-NEXT: "Mul1" +# CHECK: "type": "DT_INT32" +# CHECK: "library" +# CHECK: "versions" +# CHECK: "producer": 27 + diff --git a/tensorflow/compiler/mlir/tfjs/tests/e2e/prelu.pbtxt b/tensorflow/compiler/mlir/tfjs/tests/e2e/prelu.pbtxt new file mode 100644 index 00000000000..810db71f5e0 --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/tests/e2e/prelu.pbtxt @@ -0,0 +1,175 @@ +# RUN: tf_tfjs_translate %s -tf-input-arrays=input0 -tf-input-data-types=DT_FLOAT -tf-input-shapes=10 -tf-output-arrays=Add -tf-custom-opdefs="name: 'Prelu' input_arg: { name: 'x' type: DT_FLOAT } input_arg: { name: 'alpha' type: DT_FLOAT } output_arg: { name: 'c' type: DT_FLOAT }" -o - | FileCheck %s --dump-input-on-failure +# Add two tensor<4xi32> inputs and return the result + +node { + name: "input0" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 10 + } + } + } + } + experimental_debug_info { + } +} +node { + name: "alpha" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.5 + } + } + } + experimental_debug_info { + } +} +node { + name: "Relu" + op: "Relu" + input: "input0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + } +} +node { + name: "Neg" + op: "Neg" + input: "input0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + } +} +node { + name: "Relu1" + op: "Relu" + input: "Neg" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + } +} +node { + name: "Mul" + op: "Mul" + input: "alpha" + input: "Relu1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + 
experimental_debug_info { + } +} +node { + name: "Add" + op: "Add" + input: "Relu" + input: "Mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + } +} +node { + name: "main" + op: "_Retval" + input: "Add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "index" + value { + i: 0 + } + } +} +library { +} +versions { + producer: 344 +} + +# CHECK: "node": +# CHECK: "name": "input0", +# CHECK-NEXT: "op": "Placeholder", +# CHECK-NEXT: "attr": +# CHECK: "type": "DT_FLOAT" +# CHECK: "name": "Add.Relu.Neg.Relu1.Mul", +# CHECK-NEXT: "op": "Const", +# CHECK-NEXT: "attr": +# CHECK: "value": +# CHECK: "tensor": +# CHECK: "dtype": "DT_FLOAT", +# CHECK: "tensorShape": {}, +# CHECK: "floatVal": +# CHECK: -0.5 +# CHECK: "name": "Add.Relu.Neg.Relu1.Mul1", +# CHECK-NEXT: "op": "Prelu", +# CHECK-NEXT: "input": +# CHECK: "input0", +# CHECK: "Add.Relu.Neg.Relu1.Mul" +# CHECK: "attr": +# CHECK: "_output_shapes": +# CHECK: "list": +# CHECK: "shape": +# CHECK: "dim": +# CHECK: "size": "10" +# CHECK: "experimentalDebugInfo": {} +# CHECK: "name": "Add", +# CHECK-NEXT: "op": "_Retval", +# CHECK-NEXT: "input": +# CHECK: "Add.Relu.Neg.Relu1.Mul1" +# CHECK: "attr": +# CHECK: "T": +# CHECK: "type": "DT_FLOAT" +# CHECK: "library": {}, +# CHECK: "versions": +# CHECK: "producer": 344 + diff --git a/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc b/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc index 631bb1ae2af..a445937570e 100644 --- a/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc +++ b/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,7 +20,6 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tfjs/transforms/passes.h" @@ -47,6 +46,11 @@ void AddTFToTFJSConversionPasses(mlir::OpPassManager* pm) { // Canonicalize, CSE etc. pm->addNestedPass(mlir::createCanonicalizerPass()); pm->addNestedPass(mlir::createCSEPass()); + + // raise to executor dialect in order to use GraphDef converter + pm->addNestedPass( + mlir::CreateFunctionalToExecutorDialectConversionPass()); + pm->addNestedPass(mlir::CreateBreakUpIslandsPass()); } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfjs/translate/json_translate.cc b/tensorflow/compiler/mlir/tfjs/translate/json_translate.cc new file mode 100644 index 00000000000..7f4b8ffae09 --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/translate/json_translate.cc @@ -0,0 +1,105 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tfjs/translate/json_translate.h" + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Translation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" + +using mlir::ModuleOp; +using mlir::TranslateFromMLIRRegistration; +using std::string; +using tensorflow::Status; +using xla::StatusOr; + +// Translates the given MLIR module in the TFJS dialect to TFJS JSON +// format. Returns false on success. +// +bool tfjs::MlirToJSONTranslateFunction(ModuleOp module, + std::string* serialized_json) { + string json_output; + // Allow TF to treat TFJS ops as TF ops. + if (!tensorflow::AddTensorFlowOpPrefix("tfjs.").ok()) { + LOG(ERROR) << "Failed to add tfjs op prefix."; + return false; + } + tensorflow::GraphExportConfig confs; + confs.export_shapes = true; + confs.export_library = true; + tensorflow::FunctionLibraryDefinition flib_def( + tensorflow::OpRegistry::Global(), tensorflow::FunctionDefLibrary()); + absl::flat_hash_set control_ret_nodes; + auto graph = absl::make_unique(flib_def); + auto status = tensorflow::ConvertMlirToGraph(module, confs, &graph, &flib_def, + &control_ret_nodes); + if (!status.ok()) { + LOG(ERROR) << "Graph export failed: " << status; + return false; + } + auto graphdef = absl::make_unique(); + graph->ToGraphDef(graphdef.get()); + + // Replace the _Arg nodes of the main function with Placeholder op. 
+ auto nodes = graphdef->mutable_node(); + for (const auto& node : llvm::enumerate(*nodes)) { + if (node.value().op() == "_Arg") { + nodes->Mutable(node.index())->set_op("Placeholder"); + } + } + + tensorflow::protobuf::util::JsonPrintOptions json_options; + json_options.add_whitespace = true; + auto jsonStatus = tensorflow::protobuf::util::MessageToJsonString( + *graphdef, &json_output, json_options); + if (!jsonStatus.ok()) { + LOG(ERROR) << "Proto2Json failed: " << status; + return false; + } + *serialized_json = std::move(json_output); + return true; +} + +static mlir::LogicalResult MlirToJSONFileTranslateFunction( + ModuleOp module, llvm::raw_ostream& output) { + std::string serialized_json; + if (!tfjs::MlirToJSONTranslateFunction(module, &serialized_json)) + return mlir::failure(); + + output << serialized_json; + return mlir::success(); +} + +static TranslateFromMLIRRegistration MLIRToJSONFileTranslate( + "mlir-to-tfjs-json", MlirToJSONFileTranslateFunction); diff --git a/tensorflow/compiler/mlir/tfjs/translate/json_translate.h b/tensorflow/compiler/mlir/tfjs/translate/json_translate.h new file mode 100644 index 00000000000..0a931f770ad --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/translate/json_translate.h @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_JSON_TRANSLATE_H_ +#define TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_JSON_TRANSLATE_H_ + +#include + +#include "mlir/IR/Module.h" // from @llvm-project +#include "tensorflow/core/lib/core/status.h" + +namespace tfjs { + +// Translates the given MLIR `module` into a JSON string. Returns true if +// translation fails, otherwise returns false. +bool MlirToJSONTranslateFunction(mlir::ModuleOp module, + std::string* serialized_json); +} // namespace tfjs + +#endif // TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_JSON_TRANSLATE_H_ diff --git a/tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc b/tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc new file mode 100644 index 00000000000..e735a3c7b8c --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc @@ -0,0 +1,173 @@ + +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "absl/strings/str_split.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/ToolOutputFile.h" +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "tensorflow/compiler/mlir/init_mlir.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h" +#include "tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.h" +#include "tensorflow/compiler/mlir/tfjs/transforms/passes.h" +#include "tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +using llvm::cl::opt; +using mlir::MLIRContext; +using stream_executor::port::StatusOr; + +// NOLINTNEXTLINE +opt input_file_name(llvm::cl::Positional, + llvm::cl::desc(""), + llvm::cl::init("-")); + +// NOLINTNEXTLINE +opt import_saved_model_object_graph( + "savedmodel-objectgraph-to-mlir", + llvm::cl::desc("Import a saved model to its MLIR representation"), + llvm::cl::value_desc("dir")); + +// NOLINTNEXTLINE +opt import_saved_model_signature_defs( + "savedmodel-signaturedefs-to-mlir", + llvm::cl::desc("Import a saved model V1 to its MLIR representation"), + llvm::cl::value_desc("dir")); + +// NOLINTNEXTLINE +opt saved_model_tags( + "tf-savedmodel-tags", + llvm::cl::desc("Tags used to indicate which MetaGraphDef to import, " + "separated by ','"), + llvm::cl::init("serve")); + +// NOLINTNEXTLINE +opt saved_model_exported_names( + "tf-savedmodel-exported-names", + llvm::cl::desc("Names to export from SavedModel, separated by ','. Empty " + "(the default) means export all."), + llvm::cl::init("")); + +// NOLINTNEXTLINE +opt output_file_name("o", llvm::cl::desc(""), + llvm::cl::value_desc("filename"), + llvm::cl::init("-")); +// NOLINTNEXTLINE +opt input_mlir( + "input-mlir", + llvm::cl::desc("Take input TensorFlow model in textual MLIR instead of " + "GraphDef format"), + llvm::cl::init(false), llvm::cl::Hidden); +// NOLINTNEXTLINE +opt output_mlir( + "output-mlir", + llvm::cl::desc("Output MLIR rather than JSON for the generated TFJS model"), + llvm::cl::init(false)); + +// The following approach allows injecting opdefs in addition +// to those that are already part of the global TF registry to be linked in +// prior to importing the graph. The primary goal is for support of custom ops. +// This is not intended to be a general solution for custom ops for the future +// but mainly for supporting older models like mobilenet_ssd. More appropriate +// mechanisms, such as op hints or using functions to represent composable ops +// like https://github.com/tensorflow/community/pull/113 should be encouraged +// going forward. +// NOLINTNEXTLINE +llvm::cl::list custom_opdefs( + "tf-custom-opdefs", llvm::cl::desc("List of custom opdefs when importing " + "graphdef")); + +// Debugging flag to print function mapping in the JSON. 
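The flag set defined above matches the lit RUN lines in the e2e tests earlier in this change; a typical invocation (file names illustrative) is `tf_tfjs_translate add.pbtxt -tf-input-arrays=input0,input1 -tf-input-data-types=DT_INT32,DT_INT32 -tf-input-shapes=10:10 -tf-output-arrays=Mul -o add.json`, with `-tf-custom-opdefs` added when the GraphDef uses ops outside the global registry (as the Prelu test does) and `--output-mlir` to inspect the intermediate TFJS dialect instead of the JSON output.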
+// NOLINTNEXTLINE +static opt print_function_result_mapping( + "print-function-result-mapping", + llvm::cl::desc( + "Print the mapping of function result to json output buffer"), + llvm::cl::init(false)); + +enum TranslationStatus { kTrSuccess, kTrFailure }; + +static int PrintFunctionResultMapping(const std::string& result) { + std::cout << result << std::endl; + return kTrSuccess; +} + +int main(int argc, char** argv) { + tensorflow::InitMlir y(&argc, &argv); + + llvm::cl::ParseCommandLineOptions(argc, argv, + "TF GraphDef to TFJS JSON converter\n"); + + MLIRContext context; + llvm::SourceMgr source_mgr; + mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context); + + StatusOr module; + + if (import_saved_model_object_graph || import_saved_model_signature_defs) { + if (input_mlir) + module = tensorflow::errors::InvalidArgument( + "Importing saved model should not have input_mlir set"); + module = tensorflow::ImportSavedModel( + import_saved_model_object_graph, import_saved_model_signature_defs, + custom_opdefs, input_file_name, saved_model_tags, + saved_model_exported_names, &context); + } else { + module = tensorflow::LoadFromGraphdefOrMlirSource( + input_file_name, input_mlir, custom_opdefs, debug_info_file, + input_arrays, input_dtypes, input_shapes, output_arrays, + /*prune_unused_nodes=*/true, &source_mgr, &context); + } + + // If errors occur, the library call in the above already logged the error + // message. So we can just return here. + if (!module.ok()) return kTrFailure; + + mlir::PassManager pm(&context); + + tensorflow::AddTFToTFJSConversionPasses(&pm); + + std::string result; + auto status = tensorflow::ConvertTFOpsToTfjsJSON(module.ValueOrDie().get(), + output_mlir, &result, &pm); + if (!status.ok()) return kTrFailure; + + std::string error_msg; + auto output = mlir::openOutputFile(output_file_name, &error_msg); + if (output == nullptr) { + llvm::errs() << error_msg << '\n'; + return kTrFailure; + } + output->os() << result; + output->keep(); + + // Print out debugging info related to function mapping. + if (print_function_result_mapping) return PrintFunctionResultMapping(result); + return kTrSuccess; +} diff --git a/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.cc b/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.cc new file mode 100644 index 00000000000..7dc9ea049ba --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.cc @@ -0,0 +1,152 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h" + +#include +#include +#include +#include +#include + +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Parser.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/compiler/mlir/tfjs/translate/json_translate.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { + +using mlir::MLIRContext; +using mlir::ModuleOp; +using mlir::OwningModuleRef; +using stream_executor::port::StatusOr; + +namespace { +tensorflow::Status RegisterCustomOps( + const std::vector& extra_tf_opdefs) { + for (const auto& tf_opdefs_string : extra_tf_opdefs) { + tensorflow::OpDef opdef; + if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string, + &opdef)) { + LOG(ERROR) << "OpDef parsing failed for: " << tf_opdefs_string; + return errors::InvalidArgument("fail to parse extra OpDef"); + } + // Register extra opdefs. + tensorflow::OpRegistry::Global()->Register( + [opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status { + *op_reg_data = tensorflow::OpRegistrationData(opdef); + return Status::OK(); + }); + } + return Status::OK(); +} +} // namespace + +StatusOr LoadFromGraphdefOrMlirSource( + const std::string& input_filename, bool input_mlir, + const std::vector& extra_tf_opdefs, + absl::string_view debug_info_file, absl::string_view input_arrays, + absl::string_view input_dtypes, absl::string_view input_shapes, + absl::string_view output_arrays, bool prune_unused_nodes, + llvm::SourceMgr* source_mgr, MLIRContext* context) { + // Set up the input file. 
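RegisterCustomOps above parses each textual OpDef supplied on the command line and registers it with the global op registry so the importer can resolve ops that are not linked into the binary. A hedged usage fragment (not standalone; assumes the headers already included in this file and a caller returning Status), reusing the OpDef string the Prelu e2e test passes via -tf-custom-opdefs:

```cpp
// The same OpDef text the prelu.pbtxt test passes via -tf-custom-opdefs.
std::vector<std::string> extra_opdefs = {
    "name: 'Prelu' "
    "input_arg: { name: 'x' type: DT_FLOAT } "
    "input_arg: { name: 'alpha' type: DT_FLOAT } "
    "output_arg: { name: 'c' type: DT_FLOAT }"};
// After this, the GraphDef import in LoadFromGraphdefOrMlirSource can resolve
// 'Prelu' nodes even though the op is not built into TensorFlow.
TF_RETURN_IF_ERROR(RegisterCustomOps(extra_opdefs));
```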
+ std::string error_message; + auto file = mlir::openInputFile(input_filename, &error_message); + if (!file) { + llvm::errs() << error_message << "\n"; + return errors::InvalidArgument("fail to open input file"); + } + + if (input_mlir) { + source_mgr->AddNewSourceBuffer(std::move(file), llvm::SMLoc()); + return OwningModuleRef(mlir::parseSourceFile(*source_mgr, context)); + } + + TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs)); + + return tensorflow::GraphdefToMlirTranslateFunction( + file->getBuffer(), debug_info_file, input_arrays, input_dtypes, + input_shapes, output_arrays, /*control_output_arrays=*/"", + prune_unused_nodes, /*convert_legacy_fed_inputs=*/true, + /*graph_as_function=*/false, /*upgrade_legacy=*/true, + /*enable_shape_inference=*/true, context); +} + +Status ConvertTFOpsToTfjsJSON(mlir::ModuleOp module, bool export_to_mlir, + std::string* result, + mlir::PassManager* pass_manager) { + mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(), + /*propagate=*/true); + if (failed(pass_manager->run(module))) { + return statusHandler.ConsumeStatus(); + } + + if (export_to_mlir) { + llvm::raw_string_ostream os(*result); + module.print(os); + return Status::OK(); + } + + return tfjs::MlirToJSONTranslateFunction(module, result) + ? Status::OK() + : statusHandler.ConsumeStatus(); +} + +StatusOr ImportSavedModel( + bool import_saved_model, bool import_saved_model_v1, + const std::vector& extra_tf_opdefs, + const std::string& input_filename, const std::string& saved_model_tags, + const std::string& saved_model_exported_names, mlir::MLIRContext* context) { + std::unordered_set tags = absl::StrSplit(saved_model_tags, ','); + std::vector exported_names_in_vector = + absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty()); + absl::Span exported_names(exported_names_in_vector); + if (import_saved_model) { + auto module = tensorflow::SavedModelObjectGraphToMlirImport( + input_filename, tags, absl::Span(exported_names), context); + if (!module) + return tensorflow::errors::InvalidArgument("fail to open input file"); + TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs)); + return module; + } else if (import_saved_model_v1) { + auto module = tensorflow::SavedModelSignatureDefsToMlirImport( + input_filename, tags, exported_names, context); + + if (!module) + return tensorflow::errors::InvalidArgument("fail to open input file"); + TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs)); + return module; + } else { + return tensorflow::errors::InvalidArgument( + "Should be either saved model v1 or v2"); + } +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h b/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h new file mode 100644 index 00000000000..d68f0e7d46e --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h @@ -0,0 +1,63 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_TF_TO_TFJS_JSON_H_ +#define TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_TF_TO_TFJS_JSON_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "llvm/Support/SourceMgr.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/core/platform/status.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { + +// Load a TF model from a GraphDef definition or a TF control flow dialect MLIR +// source into a MLIR module. If `input_mlir` is true, load from a MLIR source +// file; otherwise, load from a GraphDef. +// Setting prune_unused_nodes to true, would prune unreachable nodes if +// output_arrays is specified. +stream_executor::port::StatusOr +LoadFromGraphdefOrMlirSource( + const std::string& input_filename, bool input_mlir, + const std::vector& extra_tf_opdefs, + absl::string_view debug_info_file, absl::string_view input_arrays, + absl::string_view input_dtypes, absl::string_view input_shapes, + absl::string_view output_arrays, bool prune_unused_nodes, + llvm::SourceMgr* source_mgr, mlir::MLIRContext* context); + +// Load Saved model (either v1 or v2) into MLIR. +stream_executor::port::StatusOr ImportSavedModel( + bool import_saved_model, bool import_saved_model_v1, + const std::vector& extra_tf_opdefs, + const std::string& input_filename, const std::string& saved_model_tags, + const std::string& saved_model_exported_names, mlir::MLIRContext* context); + +// Taking a MLIR module in TF executor dialect and a set of parameters, +// applies a set of passes to convert the module to TFJS dialect and +// serializes the result to JSON string. +// If `export_to_mlir` is true, the result is exported in MLIR text format, +// otherwise exported in JSON. +Status ConvertTFOpsToTfjsJSON(mlir::ModuleOp module, bool export_to_mlir, + std::string* result, + mlir::PassManager* pass_manager); +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_TF_TO_TFJS_JSON_H_ diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD deleted file mode 100644 index 88e214f601b..00000000000 --- a/tensorflow/compiler/mlir/tfrt/BUILD +++ /dev/null @@ -1,156 +0,0 @@ -load("//third_party/mlir:tblgen.bzl", "gentbl") -load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library_cc") - -# TF to TFRT kernels conversion. 
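For context on the tf_to_tfjs_json API added above (LoadFromGraphdefOrMlirSource, ConvertTFOpsToTfjsJSON, ImportSavedModel), the sketch below shows one way the first two entry points could be wired together. It is illustrative only: the wrapper name, the std::string element type for extra_tf_opdefs, and the empty pass pipeline are assumptions, not part of this change, and the real translate driver is not shown in this hunk.

```cpp
// Minimal sketch of driving the new tf_to_tfjs_json entry points.
// Assumptions (not in this diff): std::string elements for extra_tf_opdefs,
// a caller-configured TF->TFJS pipeline, and the name TranslateGraphDefToTfjsJson.
#include <string>

#include "llvm/Support/SourceMgr.h"
#include "mlir/IR/MLIRContext.h"    // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h"

tensorflow::Status TranslateGraphDefToTfjsJson(const std::string& filename,
                                               std::string* json_output) {
  mlir::MLIRContext context;
  llvm::SourceMgr source_mgr;

  // Import a GraphDef (input_mlir = false) into a TF-dialect MLIR module.
  auto module_or = tensorflow::LoadFromGraphdefOrMlirSource(
      filename, /*input_mlir=*/false, /*extra_tf_opdefs=*/{},
      /*debug_info_file=*/"", /*input_arrays=*/"", /*input_dtypes=*/"",
      /*input_shapes=*/"", /*output_arrays=*/"", /*prune_unused_nodes=*/true,
      &source_mgr, &context);
  if (!module_or.ok()) return module_or.status();

  // The caller is expected to populate the pass manager with the TF->TFJS
  // legalization pipeline before converting; an empty pipeline only
  // round-trips the module.
  mlir::PassManager pm(&context);
  return tensorflow::ConvertTFOpsToTfjsJSON(module_or.ValueOrDie().get(),
                                            /*export_to_mlir=*/false,
                                            json_output, &pm);
}
```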
-package( - default_visibility = [":friends"], - licenses = ["notice"], # Apache 2.0 -) - -package_group( - name = "friends", - packages = [ - "//learning/brain/experimental/tfrt/...", - "//tensorflow/compiler/...", - "//tensorflow/core/runtime_fallback/...", - "//tensorflow/core/tfrt/experimental/saved_model/...", - "//third_party/tf_runtime_google/...", - ], -) - -cc_library( - name = "tf_legalize_to_tfrt", - srcs = [ - "tf_legalize_to_hex.cc", - ], - deps = [ - "//tensorflow/compiler/mlir/tensorflow", - "@com_google_absl//absl/memory", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:StandardOps", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:Transforms", - ], - alwayslink = 1, -) - -filegroup( - name = "runtime_fallback_ops_td_files", - srcs = [ - "runtime_fallback/runtime_fallback_ops.td", - "@llvm-project//mlir:OpBaseTdFiles", - "@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td", - "@tf_runtime//:OpBaseTdFiles", - ], -) - -gentbl( - name = "runtime_fallback_ops_inc_gen", - tbl_outs = [ - ( - "-gen-op-decls", - "runtime_fallback_ops.h.inc", - ), - ( - "-gen-op-defs", - "runtime_fallback_ops.cc.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "runtime_fallback/runtime_fallback_ops.td", - td_includes = [ - "external/tf_runtime/include", - ], - td_srcs = [ - ":runtime_fallback_ops_td_files", - ], -) - -cc_library( - name = "runtime_fallback_opdefs_alwayslink", - srcs = [ - "runtime_fallback/dialect_static_registration.cc", - "runtime_fallback/runtime_fallback_combine.cc", - "runtime_fallback/runtime_fallback_ops.cc", - ], - hdrs = [ - "runtime_fallback/runtime_fallback_ops.h", - ], - deps = [ - ":runtime_fallback_ops_inc_gen", - "@llvm-project//mlir:Dialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:SideEffects", - "@llvm-project//mlir:Support", - "@tf_runtime//:basic_kernels_opdefs_alwayslink", - "@tf_runtime//:tensor_opdefs_alwayslink", - ], - alwayslink = 1, -) - -cc_library( - name = "lower_tf_to_tfd_alwayslink", - srcs = ["runtime_fallback/lower_tf_to_tfd.cc"], - deps = [ - "runtime_fallback_opdefs_alwayslink", - "//tensorflow/compiler/mlir/tensorflow", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:StandardOps", - "@llvm-project//mlir:Transforms", - "@tf_runtime//:basic_kernels_opdefs_alwayslink", - ], - alwayslink = 1, -) - -cc_library( - name = "tf_to_corert", - srcs = [ - "transforms/optimize.cc", - "transforms/tf_to_corert.cc", - ], - hdrs = [ - "transforms/passes.h", - ], - deps = [ - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:convert_tensor", - "//tensorflow/core:framework", - "//tensorflow/core/platform:tstring", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:StandardOps", - "@llvm-project//mlir:Transforms", - "@tf_runtime//:basic_kernels_opdefs_alwayslink", - "@tf_runtime//:core_runtime_opdefs_alwayslink", - ], - alwayslink = 1, -) - -cc_library( - name = "compatibility_analysis", - srcs = [ - "analysis/compatibility_analysis.cc", - ], - hdrs = [ - "analysis/compatibility_analysis.h", - ], - deps = [ - ":analysis/analysis_proto_cc", - ":tf_to_corert", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", - "//tensorflow/core:lib_proto_parsing", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:StandardOps", - "@llvm-project//mlir:Translation", - ], - alwayslink = 1, -) - -tf_proto_library_cc( - name = "analysis/analysis_proto", - srcs 
= ["analysis/analysis.proto"], - cc_api_version = 2, -) diff --git a/tensorflow/compiler/mlir/tfrt/analysis/analysis.proto b/tensorflow/compiler/mlir/tfrt/analysis/analysis.proto deleted file mode 100644 index 0716a243bb3..00000000000 --- a/tensorflow/compiler/mlir/tfrt/analysis/analysis.proto +++ /dev/null @@ -1,25 +0,0 @@ -syntax = "proto3"; - -package mlir.tfrt; - -message CompatibilityAnalysisReportProto { - bool unknown_dialect = 1; - bool ref_variable = 2; - bool incompatible_variable = 3; - bool incompatible_attribute = 4; - bool control_flow_v1 = 5; - - // TODO(chky): add more checks, eg. tensor datatypes. -} - -message CompatibilityAnalysisProto { - CompatibilityAnalysisReportProto summary = 1; - - message OpInfo { - int32 count = 1; - - CompatibilityAnalysisReportProto report = 2; - } - - map ops = 2; -} diff --git a/tensorflow/compiler/mlir/tfrt/analysis/compatibility_analysis.cc b/tensorflow/compiler/mlir/tfrt/analysis/compatibility_analysis.cc deleted file mode 100644 index 7e9c5544c25..00000000000 --- a/tensorflow/compiler/mlir/tfrt/analysis/compatibility_analysis.cc +++ /dev/null @@ -1,193 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/mlir/tfrt/analysis/compatibility_analysis.h" - -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Function.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/Translation.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" -#include "tensorflow/core/platform/protobuf.h" - -namespace tensorflow { -namespace { - -class CompatibilityAnalysis { - public: - void AnalyzeOperation(mlir::Operation* op); - - const mlir::tfrt::CompatibilityAnalysisProto& GetResult() const { - return analysis_; - } - - private: - // Return true if some attributes in the op are not supported. - bool AnalyzeOpAttributes(mlir::Operation* op); - // Return true if this op has unsupported operation (eg. mutate) on resource - // variables. - bool AnalyzeVariable(mlir::Operation* op); - - void UpdateReport( - const mlir::tfrt::CompatibilityAnalysisReportProto& new_report, - mlir::tfrt::CompatibilityAnalysisReportProto* old_report); - - mlir::tfrt::CompatibilityAnalysisProto analysis_; -}; - -void CompatibilityAnalysis::AnalyzeOperation(mlir::Operation* op) { - // Skip the standard ops that are allowed in tf dialect. 
- if (llvm::isa(op) || llvm::isa(op) || - llvm::isa(op) || llvm::isa(op)) - return; - - auto op_name = op->getName(); - - std::string name = op_name.getStringRef().str(); - - mlir::tfrt::CompatibilityAnalysisReportProto op_report; - - if (op_name.getDialect() == - mlir::TF::TensorFlowDialect::getDialectNamespace()) { - // Analyze op attributes. - if (AnalyzeOpAttributes(op)) op_report.set_incompatible_attribute(true); - - // Analyze variable operations. - if (AnalyzeVariable(op)) op_report.set_incompatible_variable(true); - - // Reference variable is not supported. - if (op_name.getStringRef() == "tf.VariableV2") - op_report.set_ref_variable(true); - } else if (op_name.getDialect() == "tf_executor") { - if (llvm::isa(op) || - llvm::isa(op) || - llvm::isa(op) || - llvm::isa(op) || - llvm::isa(op) || - llvm::isa(op)) { - op_report.set_control_flow_v1(true); - } else { - // Skip the rest of the tf_executor ops as they can be handled. - // - // TODO(chky): consider adding whitelist here. - return; - } - } else { - // Mark unknown dialect in the report. - op_report.set_unknown_dialect(true); - } - - auto& op_info = (*analysis_.mutable_ops())[name]; - op_info.set_count(op_info.count() + 1); - - UpdateReport(op_report, op_info.mutable_report()); - UpdateReport(op_report, analysis_.mutable_summary()); -} - -bool CompatibilityAnalysis::AnalyzeOpAttributes(mlir::Operation* op) { - // tf.Const gets special handling so it is always compatible. - if (llvm::isa(op)) return false; - - // TODO(chky): Derived attributes should be also analyzed here. - for (auto attr : op->getAttrs()) { - if (attr.first.strref() == "_output_shapes") continue; - if (attr.first.strref() == "_class") continue; - - // Symbol attributes (eg. function names) is currently not supported. - // - // TODO(chky): CoreRT should ideally support function call operatoins. - // Remove this condition once that is implemented. - if (attr.second.isa()) return true; - - // Currently only tensors of simple dtypes (i1, i32, i64, f32, f64) are - // supported. - if (auto elements_attr = attr.second.dyn_cast()) { - if (!elements_attr.isa()) return true; - auto element_type = elements_attr.getType().getElementType(); - if (element_type.isa()) return true; - } - - // Currently only arrays of simple element types (i1, i32, i64, f32, f64) - // are supported. - if (auto array_attr = attr.second.dyn_cast()) { - if (array_attr.size() > 0) { - if (array_attr[0].isa()) return true; - - if (array_attr[0].isa()) return true; - - if (array_attr[0].isa()) return true; - } - } - } - return false; -} - -bool CompatibilityAnalysis::AnalyzeVariable(mlir::Operation* op) { - // Currently only supported variable op is ReadVariableOp. 
- if (llvm::isa(op)) return false; - - for (auto value : op->getOperands()) { - auto type = value.getType(); - if (auto tensor_type = type.dyn_cast()) { - auto element_type = tensor_type.getElementType(); - if (element_type.isa()) return true; - } - } - - return false; -} - -void CompatibilityAnalysis::UpdateReport( - const mlir::tfrt::CompatibilityAnalysisReportProto& new_report, - mlir::tfrt::CompatibilityAnalysisReportProto* old_report) { - if (new_report.unknown_dialect()) old_report->set_unknown_dialect(true); - - if (new_report.ref_variable()) old_report->set_ref_variable(true); - - if (new_report.incompatible_variable()) - old_report->set_incompatible_variable(true); - - if (new_report.incompatible_attribute()) - old_report->set_incompatible_attribute(true); - - if (new_report.control_flow_v1()) old_report->set_control_flow_v1(true); -} - -} // namespace - -mlir::tfrt::CompatibilityAnalysisProto AnalyzeTFCompatibility( - mlir::ModuleOp op) { - CompatibilityAnalysis analysis; - op.walk([&analysis](mlir::Operation* op) { analysis.AnalyzeOperation(op); }); - return analysis.GetResult(); -} - -static mlir::TranslateFromMLIRRegistration registration( - "analyze-tf-for-tfrt", [](mlir::ModuleOp op, llvm::raw_ostream& output) { - auto analysis_proto = AnalyzeTFCompatibility(op); - std::string text_proto; - if (tensorflow::protobuf::TextFormat::PrintToString(analysis_proto, - &text_proto)) { - output << text_proto; - return mlir::success(); - } - - return mlir::failure(); - }); - -} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/runtime_fallback/lower_tf_to_tfd.cc b/tensorflow/compiler/mlir/tfrt/runtime_fallback/lower_tf_to_tfd.cc deleted file mode 100644 index 5f831c9ef6a..00000000000 --- a/tensorflow/compiler/mlir/tfrt/runtime_fallback/lower_tf_to_tfd.cc +++ /dev/null @@ -1,390 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
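The compatibility analysis removed above is normally reached through the `analyze-tf-for-tfrt` translation registration, but it can also be queried directly via AnalyzeTFCompatibility. Below is a hedged sketch of a C++ consumer; the helper name, the logging choices, and the use of the generated analysis.proto accessors are illustrative assumptions rather than code from this diff.

```cpp
// Sketch: walk the proto produced by AnalyzeTFCompatibility (declared in the
// removed compatibility_analysis.h) and log ops flagged as incompatible.
#include "mlir/IR/Module.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tfrt/analysis/compatibility_analysis.h"
#include "tensorflow/core/platform/logging.h"

void PrintIncompatibleOps(mlir::ModuleOp module) {
  mlir::tfrt::CompatibilityAnalysisProto proto =
      tensorflow::AnalyzeTFCompatibility(module);

  // The summary is the OR of every per-op report (see UpdateReport above).
  if (proto.summary().control_flow_v1()) {
    LOG(WARNING) << "Module still uses TF control flow v1.";
  }
  for (const auto& entry : proto.ops()) {
    const auto& report = entry.second.report();
    if (report.incompatible_attribute() || report.incompatible_variable() ||
        report.ref_variable() || report.unknown_dialect()) {
      LOG(INFO) << entry.first << " (count=" << entry.second.count()
                << ") is not yet supported on TFRT.";
    }
  }
}
```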
-==============================================================================*/ - -#include -#include - -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Function.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_ops.h" -#include "tfrt/basic_kernels/opdefs/basic_kernels.h" - -namespace mlir { -namespace { - -constexpr const char kTmpLoweringCastOpName[] = "tmp_lowering_cast_op"; - -static Type GetChainType(MLIRContext* context) { - auto hexDialect = Identifier::get("hex", context); - return OpaqueType::get(hexDialect, "chain", context); -} - -static Type GetTfdTensorType(MLIRContext* context) { - auto tfdDialect = Identifier::get("tfd", context); - return OpaqueType::get(tfdDialect, "tf_tensor", context); -} - -struct TfToTfdLoweringPass - : public PassWrapper> { - void runOnOperation() final; -}; - -class FuncOpSignatureConversion : public OpConversionPattern { - public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - FuncOp funcOp, llvm::ArrayRef operands, - ConversionPatternRewriter& rewriter) const override { - auto ctx = funcOp.getContext(); - auto chain_type = GetChainType(ctx); - auto tfd_tensor_type = GetTfdTensorType(ctx); - FunctionType type = funcOp.getType(); - - // Convert function return results. The lowered function is expected to - // return a chain as the first return result. For each original TF tensor, - // the lowered function returns a TFD tensor instead. - llvm::SmallVector converted_results; - if (type.getNumResults() > 0) { - // Add a chain as the first return result. - converted_results.push_back(chain_type); - - // Convert the original TF tensor return results. - for (unsigned i = 0, e = type.getNumResults(); i != e; ++i) { - if (auto tensor_type = type.getResult(i).dyn_cast()) { - // Each TF tensor is converted to a TFD tensor. - converted_results.push_back(tfd_tensor_type); - } else { - // Only handle TF tensor conversion for now. - return failure(); - } - } - } - - // Create the new function signature. The lowered function is expected to - // take a Chain as the first argument. Then for each TF tensor argument, - // expect a TFD tensor argument instead. - TypeConverter::SignatureConversion new_func_sig(type.getNumInputs() + 1); - if (type.getNumInputs() > 0) { - // Add the first chain argument. - new_func_sig.addInputs(chain_type); - for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i) { - // For each original TF tensor type, convert it to one TFD tensor type. - if (auto tensor_type = type.getInput(i).dyn_cast()) { - new_func_sig.addInputs(i, {tfd_tensor_type}); - } else { - // Only handle TF tensor argument for now. - return failure(); - } - } - } - // Each function has a single region. In general, each region can have - // multiple blocks. Assume that all TF-dialect functions only have a - // single entry block. - Block* entry = &funcOp.front(); - - // Tell the rewriter to convert the region signature. After this, the - // function region takes the new function signature, which means index - // shifts by one. - Block* convertedEntry = - rewriter.applySignatureConversion(&funcOp.getBody(), new_func_sig); - - { - // Generate the "fake" mapping ops. 
The insertion guard restores rewriter - // insertion pointer when it gets out of scope. - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(convertedEntry); - // Replace block arguments. For example, - // func @example(i64, i1) -> i64 { - // ^bb0(%a: i64, %cond: i1): // replacing this. - for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i) { - // For each original block argument, create a fake op that takes the - // input the input chain argument to the function, and the tfd tensor - // argument, and returns the original TF tensor input. Note that the - // function signature has been replaced, so entry->getArgument(0) is the - // input chain. And we need to add 1 to index to get the original - // argument. - Type orig_input = type.getInput(i); - OperationState tmp_lowering_cast_op( - funcOp.getLoc(), kTmpLoweringCastOpName, - {convertedEntry->getArgument(0), - convertedEntry->getArgument(i + 1)}, - orig_input, {}); - Value repl_value = - rewriter.createOperation(tmp_lowering_cast_op)->getResult(0); - // Replace original uses of TF tensor block argument with the result of - // the fake op. This sets up the lowering passes for individual ops - // which at this point still expect TF tensors rather than TFD tensor - // inputs. - rewriter.replaceUsesOfBlockArgument(entry->getArgument(i), repl_value); - } - } - - // Create a new function op with an updated signature. - auto new_func_op = rewriter.cloneWithoutRegions(funcOp); - rewriter.inlineRegionBefore(funcOp.getBody(), new_func_op.getBody(), - new_func_op.end()); - new_func_op.setType(FunctionType::get(new_func_sig.getConvertedTypes(), - converted_results, ctx)); - // Remove the old function op. - rewriter.eraseOp(funcOp); - return success(); - } -}; - -// Lower each TF op to a tfd.delegate_kernel op. For example, -// -// %1 = "tf.ReadVariableOp"(%arg) { -// dtype = "tfdtype$DT_FLOAT" -// } : (tensor<*x!tf.resource>) -> tensor<10xf32> -// -// would be lowered to -// -// %1:2 = "tfd.delegate_kernel"(%chain_in, %arg) { -// _name = "tf.ReadVariableOp", -// attr0_name = "dtype", attr0_value = "tfdtype$DT_FLOAT" -// } : (!hex.chain, !tfd.tf_tensor) -> (!hex.chain, !tfd.tf_tensor) -// -// Each tfd.delegate_kernel op expects a chain as the first input. This chain -// may come from the first function argument or the previous converted op -// output. The rest of inputs would be converted to a tfd tensor input. -// Each tfd.delegate_kernel op returns a chain as the first output. Each -// original output TensorType is converted a tfd tensor type. -// The TF op name becomes an _name attribute. Each TF attribute is lowered to -// two TFD attributes, one for the name, one for the type and value. -// -// Because delegate_kernel ops are threaded through chains, we lowered to a -// serial execution plan. -// TODO(zhangqiaorjc): Do analysis to allow concurrent execution. -template -class TFOpConversion : public OpConversionPattern { - public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - TF_OP op, llvm::ArrayRef operands, - ConversionPatternRewriter& rewriter) // NOLINT(google-runtime-references - const override { - auto ctx = op.getContext(); - // Handle new op operands. - // Delegate kernel expects the first argument to be a chain, followed by - // original arguments to the target TF op converted to TFD tensors. 
- llvm::SmallVector delegate_kernel_op_operands; - int num_new_operands = op.getOperation()->getNumOperands() + 1; - delegate_kernel_op_operands.reserve(num_new_operands); - - // Get the input chain from the previous delegate_kernel op or first block - // argument. - Value chain_input = nullptr; - auto* block = op.getOperation()->getBlock(); - assert(block->isEntryBlock() && "only supports a single block"); - // Find a previous delegate_kernel op for its output chain. - auto* prev_op = op.getOperation()->getPrevNode(); - while (prev_op != nullptr && !isa(prev_op)) { - prev_op = prev_op->getPrevNode(); - } - if (prev_op != nullptr) { - // There is another delegate kernel op before this op. - auto prev_op_result_0 = prev_op->getResult(0); - assert(prev_op_result_0.getType() == GetChainType(ctx)); - chain_input = prev_op_result_0; - } else { - // This op is the first delegate kernel op in a block. - auto arg_0 = block->getArgument(0); - assert(arg_0.getType() == GetChainType(ctx)); - chain_input = arg_0; - } - delegate_kernel_op_operands.push_back(chain_input); - - // Convert each TensorType operand to the corresponding TFD tensor operand. - for (auto operand : operands) { - auto* tmp_lowering_cast_op = operand.getDefiningOp(); - assert(tmp_lowering_cast_op->getName().getStringRef() == - kTmpLoweringCastOpName); - delegate_kernel_op_operands.push_back( - tmp_lowering_cast_op->getOperand(1)); - } - - // Handle new op results. - llvm::SmallVector delegate_kernel_op_results; - // The first output is a chain. - delegate_kernel_op_results.push_back(GetChainType(ctx)); - // For each original output, there is a corresponding TFD tensor output. - for (int i = 0, e = op.getOperation()->getNumResults(); i != e; ++i) { - delegate_kernel_op_results.push_back(GetTfdTensorType(ctx)); - } - - // Convert TF attribute to TFD attribute. - llvm::SmallVector delegate_kernel_op_attributes; - NamedAttribute op_name_attr(Identifier::get("_name", ctx), - StringAttr::get(op.getOperationName(), ctx)); - delegate_kernel_op_attributes.push_back(op_name_attr); - - int attr_idx = 0; - for (const NamedAttribute& tf_attr : op.getAttrs()) { - // Small std::string benefits from small string optimization in libc++. - NamedAttribute attr_name( - Identifier::get("attr" + std::to_string(attr_idx) + "_name", ctx), - StringAttr::get(tf_attr.first, ctx)); - NamedAttribute attr_value( - Identifier::get("attr" + std::to_string(attr_idx) + "_value", ctx), - tf_attr.second); - delegate_kernel_op_attributes.push_back(attr_name); - delegate_kernel_op_attributes.push_back(attr_value); - attr_idx++; - } - - // Replace the TF op with TFD delegate kernel op. - auto new_op = rewriter.create( - op.getLoc(), delegate_kernel_op_results, delegate_kernel_op_operands, - delegate_kernel_op_attributes); - - // Create lowering cast ops for non-chain results. - llvm::SmallVector lowering_cast_ops_values; - // Skip the first result. It's a chain which has no current users. 
- for (int i = 1, e = new_op.getOperation()->getNumResults(); i != e; ++i) { - Type orig_input = op.getType(); - OperationState tmp_lowering_cast_op(new_op.getLoc(), - kTmpLoweringCastOpName, - {new_op.getOperation()->getResult(0), - new_op.getOperation()->getResult(i)}, - {orig_input}, {}); - Value repl_value = - rewriter.createOperation(tmp_lowering_cast_op)->getResult(0); - lowering_cast_ops_values.push_back(repl_value); - } - - rewriter.replaceOp(op, lowering_cast_ops_values); - return success(); - } -}; - -class ReturnOpConversion : public OpConversionPattern { - public: - using OpConversionPattern::OpConversionPattern; - - // Replace std.return with hex.return. The first result is always a chain and - // each original TF tensor result is converted to a TFD tensor. - LogicalResult matchAndRewrite( - ReturnOp return_op, llvm::ArrayRef operands, - ConversionPatternRewriter& rewriter) const override { - auto ctx = return_op.getContext(); - Value chain_output = nullptr; - llvm::SmallVector new_return_op_operands; - new_return_op_operands.reserve(return_op.getNumOperands() + 1); - // Convert each TF tensor operand to the corresponding TFD tensor operand. - for (auto operand : operands) { - auto* tmp_lowering_cast_op = operand.getDefiningOp(); - if (tmp_lowering_cast_op->getName().getStringRef() != - kTmpLoweringCastOpName) { - assert(false && "unexpected producer of operand"); - } - if (chain_output == nullptr) { - // Get the input chain from the previous op or first block argument. - auto* block = return_op.getOperation()->getBlock(); - if (!block->isEntryBlock()) { - assert(false && "only supports a single block"); - } - // Find a previous delegate_kernel op for its output chain. - auto* prev_op = return_op.getOperation()->getPrevNode(); - while (prev_op != nullptr && !isa(prev_op)) { - prev_op = prev_op->getPrevNode(); - } - if (prev_op != nullptr) { - // There is another delegate kernel op before this op. - auto prev_op_result_0 = prev_op->getResult(0); - if (prev_op_result_0.getType() != GetChainType(ctx)) { - assert(false && - "delegate kernel must produce chain as the first result"); - } - chain_output = prev_op_result_0; - } else { - // This op is the first delegate kernel op in a block. - auto arg_0 = block->getArgument(0); - if (arg_0.getType() != GetChainType(ctx)) { - assert(false && "first block argument must be a chain"); - } - chain_output = arg_0; - } - new_return_op_operands.push_back(chain_output); - } - new_return_op_operands.push_back(tmp_lowering_cast_op->getOperand(1)); - } - // Replace the old std.return op with the new hex.return op. - rewriter.create(return_op.getLoc(), - new_return_op_operands); - rewriter.eraseOp(return_op); - - return success(); - } -}; - -void TfToTfdLoweringPass::runOnOperation() { - ConversionTarget target(getContext()); - - // Make tmp_lowering_cast_op legal for conversion. But delete them after the - // passes. - OperationName tmp_lowering_cast_op_name(kTmpLoweringCastOpName, - &getContext()); - target.setOpAction(tmp_lowering_cast_op_name, - ConversionTarget::LegalizationAction::Legal); - - // target.addLegalDialect(); - target.addLegalDialect(); - - target.addDynamicallyLegalOp([](FuncOp function) { - // Returns true if this function is legal, i.e. all inputs and outputs are - // TFRT types. 
- FunctionType type = function.getType(); - for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i) { - if (type.getInput(i).isa()) return false; - } - for (unsigned i = 0, e = type.getNumResults(); i != e; ++i) { - if (type.getResult(i).isa()) return false; - } - return true; - }); - - target.addLegalOp(); - - OwningRewritePatternList patterns; - patterns.insert, - TFOpConversion, TFOpConversion, - TFOpConversion, TFOpConversion, - ReturnOpConversion>(&getContext()); - - if (failed(applyPartialConversion(getOperation(), target, patterns))) - signalPassFailure(); - - // Delete the tmp_lowering_cast_op's since they are illegal. - getOperation().walk([&tmp_lowering_cast_op_name](Operation* op) { - if (op->getName() == tmp_lowering_cast_op_name) op->erase(); - }); -} - -} // namespace -} // namespace mlir - -static mlir::PassRegistration pass( - "tf-to-tfd-lowering", "Lowers the TF dialect to Runtime Fallback dialect."); diff --git a/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_combine.cc b/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_combine.cc deleted file mode 100644 index 4fd57af55cc..00000000000 --- a/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_combine.cc +++ /dev/null @@ -1,80 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -//===----------------------------------------------------------------------===// -// -// This file implements a set of simple combiners for optimizing operations in -// the Runtime Fallback dialect. -// -//===----------------------------------------------------------------------===// - -#include "mlir/IR/Matchers.h" -#include "mlir/IR/PatternMatch.h" -#include "tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_ops.h" - -// This optimizes the following scenario: -// %tft0, %c2 = "tfd.move_dht_to_tft"(%dht0, %c1) -// : (!dht.host_tensor, !hex.chain) -> (!tfd.tf_tensor, !hex.chain) -// %dht1, %c3 = "tfd.convert_tft_to_dht"(%tft0, %c2) -// : (!tfd.tf_tensor, !hex.chain) -> (!dht.host_tensor, !hex.chain) -// some_op %dht1, %c3 -// -// becomes -// some_op %dht0, %c1 - -struct SimplifyDoubleConversion - : public mlir::OpRewritePattern { - // We register this pattern to match every tfd.move_dht_to_tft op. - // The "benefit" is used by the framework to order the patterns and process - // them in order of profitability. - explicit SimplifyDoubleConversion(mlir::MLIRContext* context) - : mlir::OpRewritePattern(context, - /*benefit=*/1) {} - - // This method attempts to match a pattern and rewrite it. The rewriter - // argument is the orchestrator of the sequence of rewrites. The pattern is - // expected to interact with it to perform any changes to the IR from here. - mlir::LogicalResult matchAndRewrite( - mlir::tfd::ConvertTftToDhtOp op, - mlir::PatternRewriter& rewriter) const override { - // Look through the inputs of the ConvertTftToDhtOp. 
- mlir::Value convert_op_input_0 = op.getOperand(0); - mlir::Value convert_op_input_1 = op.getOperand(1); - mlir::tfd::MoveDhtToTftOp move_input_op_0 = - llvm::dyn_cast_or_null( - convert_op_input_0.getDefiningOp()); - mlir::tfd::MoveDhtToTftOp move_input_op_1 = - llvm::dyn_cast_or_null( - convert_op_input_1.getDefiningOp()); - - // The inputs should be MoveDhtToTftOp. - if (!move_input_op_0 || !move_input_op_1) return mlir::failure(); - // Both inputs are the same MoveDhtToTftOp. - if (move_input_op_0 != move_input_op_1) return mlir::failure(); - - // Use the rewriter to replace the ConvertTftToDhtOp's users with the - // operands of MoveDhtToTftOp. - rewriter.replaceOp( - op, {move_input_op_0.getOperand(0), move_input_op_0.getOperand(1)}); - return mlir::success(); - } -}; - -// Register rewrite pattern as "canonicalization" patterns on the MoveDhtToTftOp -// so that they can be picked up by the Canonicalization framework. -void mlir::tfd::ConvertTftToDhtOp::getCanonicalizationPatterns( - OwningRewritePatternList& results, MLIRContext* context) { - results.insert(context); -} diff --git a/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_ops.cc b/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_ops.cc deleted file mode 100644 index 9c69154673b..00000000000 --- a/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_ops.cc +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_ops.h" - -namespace mlir { -namespace tfd { - -//===----------------------------------------------------------------------===// -// TfrtDelegate Dialect -//===----------------------------------------------------------------------===// - -RuntimeFallbackDialect::RuntimeFallbackDialect(MLIRContext *context) - : Dialect(/*name=*/"tfd", context) { - allowUnknownTypes(); - - allowUnknownOperations(); - - addOperations< -#define GET_OP_LIST -#include "tensorflow/compiler/mlir/tfrt/runtime_fallback_ops.cc.inc" - >(); -} - -//===----------------------------------------------------------------------===// -// TableGen'd op method definitions -//===----------------------------------------------------------------------===// - -#define GET_OP_CLASSES -#include "tensorflow/compiler/mlir/tfrt/runtime_fallback_ops.cc.inc" - -} // namespace tfd -} // namespace mlir diff --git a/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_ops.h b/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_ops.h deleted file mode 100644 index 009d565e40d..00000000000 --- a/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_ops.h +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
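Because SimplifyDoubleConversion above is registered through getCanonicalizationPatterns, it is picked up by MLIR's generic canonicalizer rather than by a dedicated pass; the opt.mlir test later in this diff drives it with -pass-pipeline='func(canonicalize)'. A rough programmatic equivalent is sketched below, assuming the MLIR pass APIs of this vintage; the function name is illustrative.

```cpp
// Sketch: run the generic canonicalizer so that patterns registered via
// getCanonicalizationPatterns (such as SimplifyDoubleConversion) are applied.
#include "mlir/IR/Function.h"        // from @llvm-project
#include "mlir/IR/Module.h"          // from @llvm-project
#include "mlir/Pass/PassManager.h"   // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project

mlir::LogicalResult CanonicalizeRuntimeFallbackOps(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  // Mirrors -pass-pipeline='func(canonicalize)': canonicalize each function.
  pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  return pm.run(module);
}
```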
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file defines the operations used in the Runtime Fallback dialect. - -#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_RUNTIME_FALLBACK_RUNTIME_FALLBACK_OPS_H_ -#define TENSORFLOW_COMPILER_MLIR_TFRT_RUNTIME_FALLBACK_RUNTIME_FALLBACK_OPS_H_ - -#include "mlir/IR/Dialect.h" // from @llvm-project -#include "mlir/IR/OpDefinition.h" // from @llvm-project -#include "mlir/IR/TypeUtilities.h" // from @llvm-project -#include "mlir/Interfaces/SideEffects.h" // from @llvm-project - -namespace mlir { -namespace tfd { - -// Dialect for TFRT delegate operations. -class RuntimeFallbackDialect : public Dialect { - public: - explicit RuntimeFallbackDialect(MLIRContext* context); - static StringRef getDialectNamespace() { return "tfd"; } -}; - -#define GET_OP_CLASSES -#include "tensorflow/compiler/mlir/tfrt/runtime_fallback_ops.h.inc" - -} // namespace tfd -} // namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_TFRT_RUNTIME_FALLBACK_RUNTIME_FALLBACK_OPS_H_ diff --git a/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_ops.td b/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_ops.td deleted file mode 100644 index aeed800a1c3..00000000000 --- a/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_ops.td +++ /dev/null @@ -1,158 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This is the definition file for the Runtime Fallback Dialect. - -#ifdef TFRT_DELEGATE_DIALECT -#else -#define TFRT_DELEGATE_DIALECT - -include "tfrt/tfrt_op_base.td" -include "mlir/Interfaces/SideEffects.td" - -//===----------------------------------------------------------------------===// -// Type definitions -//===----------------------------------------------------------------------===// -def TfTensorType : OpaqueType<"tfd", "tf_tensor", "!tfd.tf_tensor type">; - -//===----------------------------------------------------------------------===// -// Runtime Fallback Dialect definitions -//===----------------------------------------------------------------------===// - -def RuntimeFallback_Dialect : Dialect { - let name = "tfd"; - - let description = [{ - The Runtime Fallback dialect. - - This dialect contains operations to run existing TF kernels on TFRT by - invoking TF Eager API. 
- }]; - - let cppNamespace = "tfd"; -} - -//===----------------------------------------------------------------------===// -// Runtime Fallback Dialect Ops definitions -//===----------------------------------------------------------------------===// - -// Base class for the operation in this dialect. -class RuntimeFallbackDialect_Op traits = []> : - Op { } - -def InitEagerContextOp : RuntimeFallbackDialect_Op<"init_eager_context"> { - let summary = "eager context initialization operation"; - let description = [{ - The "tfd.init_eager_context" operation takes an input chain, creates and - initializes the TF EagerContext and returns an output chain. - - Example: - %c1 = "tfd.init_eager_context"(%c0): (!hex.chain) -> !hex.chain - }]; - - let arguments = (ins ChainType); - let results = (outs ChainType); -} - -def DelegateKernelOp : RuntimeFallbackDialect_Op<"delegate_kernel"> { - let summary = "delegate kernel operation"; - let description = [{ - The "tfd.delegate_kernel" operation takes an input chain, and arbitrary - number of input arguments, and runs a specified TF op via TFE C API. It - returns an output chain and variable number of outputs from the TF op. - - The input arguments and attributes are passed to the TF op. The ouputs are - outputs of the TF op. - - Note that `_name` is a required attribute specifying the TF op to run. - TFRT attributes are sorted alphabetically, passed in as positional - attributes to the TFRT kernel, rather than as named attributes. - - Example: - To run "tf.MatMul" op, which has two boolean attributes, - 1. Set _name = "MatMul" - 2. For each TF attribute, split it into two attributes, one for name of - the TF attribute, and the other for the type and value of the - attribute value. Attribute value is a string with the format of - "type$val", where type can be "bool", "string", "tfdtype", "tfshape", - "tftensor". - The value serialization format can be found in attr_util.h. - - %out_c, %out_tensor = "tfd.delegate_kernel"( - %in_c, %in1_tensor, %in2_tensor) { - _name = "MatMul", - attr1_name = "transpose_a", attr1_value = "bool$false", - attr2_name = "transpose_b", attr2_value = "bool$false" - } : (!hex.chain, !tfd.tf_tensor, !tfd.tf_tensor) -> ( - !hex.chain, !tfd.tf_tensor) - }]; - - let arguments = (ins ChainType, Variadic); - let results = (outs ChainType, Variadic); -} - -def PrintTftOp : RuntimeFallbackDialect_Op<"print_tft"> { - let summary = "print TF tensor operation"; - let description = [{ - The "tfd.print_tft" operation prints the input TF tensor. It takes an input - TF tensor to be printed and an input chain, and returns an output chain. - - Example: - %c1 = "tfd.print_tft"(%t, %c) : (!tfd.tf_tensor, !hex.chain) -> !hex.chain - - }]; - - let arguments = (ins TfTensorType, ChainType); - let results = (outs ChainType); -} - -def ConvertTftToDhtOp : RuntimeFallbackDialect_Op<"convert_tft_to_dht", [NoSideEffect]> { - let summary = "convert TF tensor to TFRT DHT tensor operation"; - let description = [{ - The "tfd.convert_tft_to_dht" operation converts a TF tensor to a TFRT - DenseHostTensor. - - It takes as input a TF Tensor and an input chain, and returns a converted - TFRT DHT tensor and an output chain. - - Example: - %dht, %c0 = "tfd.convert_tft_to_dht"(%tft, %c) - : (!tfd.tf_tensor, !hex.chain) -> (!dht.host_tensor, !hex.chain) - }]; - - let arguments = (ins TfTensorType, ChainType); - // Enable registering canonicalization patterns with this operation. 
- let hasCanonicalizer = 1; - let results = (outs TensorType, ChainType); -} - -def MoveDhtToTftOp : RuntimeFallbackDialect_Op<"move_dht_to_tft", [NoSideEffect]> { - let summary = "convert TFRT DHT tensor to DHT tensor operation"; - let description = [{ - The "tfd.move_dht_to_tft" operation moves a TFRT tensor into a TF Tensor. - - It takes as input a TFRT Tensor and an input chain, and returns a TF tensor - with the same underlying buffer and an output chain. - - Example: - %dht, %c0 = "tfd.convert_tft_to_dht"(%tft, %c) - : (!tfd.tf_tensor, !hex.chain) -> (!dht.host_tensor, !hex.chain) - }]; - - let arguments = (ins TensorType, ChainType); - let results = (outs TfTensorType, ChainType); -} - -#endif // TFRT_DELEGATE_DIALECT diff --git a/tensorflow/compiler/mlir/tfrt/tests/BUILD b/tensorflow/compiler/mlir/tfrt/tests/BUILD deleted file mode 100644 index 4faa8d2efe8..00000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") - -package(licenses = ["notice"]) - -glob_lit_tests( - data = [":test_utilities"], - driver = "@llvm-project//mlir:run_lit.sh", - test_file_exts = ["mlir"], -) - -# Bundle together all of the test utilities that are used by tests. -filegroup( - name = "test_utilities", - testonly = True, - data = [ - "//tensorflow/compiler/mlir:tf-opt", - "@llvm-project//llvm:FileCheck", - ], -) diff --git a/tensorflow/compiler/mlir/tfrt/tests/analysis/BUILD b/tensorflow/compiler/mlir/tfrt/tests/analysis/BUILD deleted file mode 100644 index fc7c142ea73..00000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/analysis/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") - -package(licenses = ["notice"]) - -glob_lit_tests( - data = [":test_utilities"], - driver = "@llvm-project//mlir:run_lit.sh", - test_file_exts = ["mlir"], -) - -# Bundle together all of the test utilities that are used by tests. 
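As a side note on the "type$val" attribute encoding documented for tfd.delegate_kernel above (e.g. "bool$false", "tfdtype$DT_FLOAT"): the authoritative decoding lives in attr_util.h, which is not part of this diff. The sketch below only illustrates splitting the wire format, with a hypothetical helper name.

```cpp
// Sketch: split a "type$val" attribute string into its type tag and value.
// Purely illustrative; real parsing and validation happen in attr_util.h.
#include <string>
#include <utility>

std::pair<std::string, std::string> SplitTypedAttrValue(
    const std::string& encoded) {
  // The tag is everything before the first '$'; the value is the rest.
  const std::size_t pos = encoded.find('$');
  if (pos == std::string::npos) return {std::string(), encoded};
  return {encoded.substr(0, pos), encoded.substr(pos + 1)};
}
```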
-filegroup( - name = "test_utilities", - testonly = True, - data = [ - "//tensorflow/compiler/mlir:tf-mlir-translate", - "@llvm-project//llvm:FileCheck", - ], -) diff --git a/tensorflow/compiler/mlir/tfrt/tests/analysis/compatibility_analysis.mlir b/tensorflow/compiler/mlir/tfrt/tests/analysis/compatibility_analysis.mlir deleted file mode 100644 index 5943997a1bc..00000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/analysis/compatibility_analysis.mlir +++ /dev/null @@ -1,65 +0,0 @@ -// RUN: tf-mlir-translate -analyze-tf-for-tfrt %s | FileCheck %s - -func @main(%serialized: tensor<32x!tf.string>, - %names : tensor<32x!tf.string>, - %dense_keys : tensor<2x!tf.string>, - %dense_default_0 : tensor, - %dense_default_1 : tensor) { - // CHECK: summary { - // CHECK-NEXT: ref_variable: true - // CHECK-NEXT: incompatible_variable: true - // CHECK-NEXT: } - // CHECK-NEXT: ops { - // CHECK-NEXT: key: "tf.AssignVariableOp" - // CHECK-NEXT: value { - // CHECK-NEXT: count: 1 - // CHECK-NEXT: report { - // CHECK-NEXT: incompatible_variable: true - // CHECK-NEXT: } - // CHECK-NEXT: } - // CHECK-NEXT: } - // CHECK-NEXT: ops { - // CHECK-NEXT: key: "tf.Const" - // CHECK-NEXT: value { - // CHECK-NEXT: count: 2 - // CHECK-NEXT: report { - // CHECK-NEXT: } - // CHECK-NEXT: } - // CHECK-NEXT: } - // CHECK-NEXT: ops { - // CHECK-NEXT: key: "tf.ParseExampleV2" - // CHECK-NEXT: value { - // CHECK-NEXT: count: 1 - // CHECK-NEXT: report { - // CHECK-NEXT: } - // CHECK-NEXT: } - // CHECK-NEXT: } - // CHECK-NEXT: ops { - // CHECK-NEXT: key: "tf.VarHandleOp" - // CHECK-NEXT: value { - // CHECK-NEXT: count: 1 - // CHECK-NEXT: report { - // CHECK-NEXT: } - // CHECK-NEXT: } - // CHECK-NEXT: } - // CHECK-NEXT: ops { - // CHECK-NEXT: key: "tf.VariableV2" - // CHECK-NEXT: value { - // CHECK-NEXT: count: 1 - // CHECK-NEXT: report { - // CHECK-NEXT: ref_variable: true - // CHECK-NEXT: } - // CHECK-NEXT: } - // CHECK-NEXT: } - %0 = "tf.VariableV2"() {shape = #tf.shape<2>, container = "", shared_name = ""} : () -> tensor - %1 = "tf.Const"() {value = dense<4.200000e+01> : tensor} : () -> tensor - %2 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> - "tf.AssignVariableOp"(%2, %1) : (tensor<*x!tf.resource>, tensor) -> () - %empty_str_vector = "tf.Const"() - {dtype = !tf.string, value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B207D207D"> : tensor<0x!tf.string>} - : () -> tensor<0x!tf.string> - %result:2 = "tf.ParseExampleV2"(%serialized, %names, %empty_str_vector, %dense_keys, %empty_str_vector, %dense_default_0, %dense_default_1) - {dense_shapes = [#tf.shape<>, #tf.shape<>], num_sparse = 0 : i64, result_segment_sizes = dense<[0, 0, 0, 2, 0, 0]> : vector<6xi32>} - : (tensor<32x!tf.string>, tensor<32x!tf.string>, tensor<0x!tf.string>, tensor<2x!tf.string>, tensor<0x!tf.string>, tensor, tensor) -> (tensor<32xf32>, tensor<32xf32>) - return -} diff --git a/tensorflow/compiler/mlir/tfrt/tests/basics.mlir b/tensorflow/compiler/mlir/tfrt/tests/basics.mlir deleted file mode 100644 index 650bd04b882..00000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/basics.mlir +++ /dev/null @@ -1,19 +0,0 @@ -// RUN: tf-opt -tf-legalize-to-hex %s -o -| FileCheck %s - - -// CHECK-LABEL: func @constants() { -func @constants() { - // CHECK: "hex.constant_int"() {value = 1 : i32} - %0 = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "x", value = dense<1> : tensor} : () -> tensor - // CHECK: "hex.constant_int"() {value = 42 : i32} - %1 = 
"tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "y", value = dense<42> : tensor<1x1xi32>} : () -> tensor<1x1xi32> - // CHECK: hex.return - return -} - -// CHECK-LABEL: func @add -func @add(%arg0: tensor<1xi32>) { - // CHECK: hex.add_int - %2 = "tf.Add"(%arg0, %arg0) {T = "tfdtype$DT_INT32", device = "", name = "z"} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> - return -} diff --git a/tensorflow/compiler/mlir/tfrt/tests/err_partial_convert.mlir b/tensorflow/compiler/mlir/tfrt/tests/err_partial_convert.mlir deleted file mode 100644 index 410ff299883..00000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/err_partial_convert.mlir +++ /dev/null @@ -1,9 +0,0 @@ -// RUN: tf-opt %s -tf-legalize-to-hex -verify-diagnostics - -func @partial_convert() { - %0 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - // expected-error @+1 {{failed to legalize operation 'tf.Const'}} - %1 = "tf.Const"() {value = dense<42> : tensor<2xi32>} : () -> tensor<2xi32> - %2 = "tf.Add"(%0, %1) : (tensor, tensor<2xi32>) -> tensor<2xi32> - return -} diff --git a/tensorflow/compiler/mlir/tfrt/tests/opt.mlir b/tensorflow/compiler/mlir/tfrt/tests/opt.mlir deleted file mode 100644 index 6f27fa6d7e4..00000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/opt.mlir +++ /dev/null @@ -1,26 +0,0 @@ -// RUN: tf-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s - -// CHECK-LABEL: func @simplify_double_conversion_test( -func @simplify_double_conversion_test() { - // CHECK: %[[CREATE:.*]] = dht.create - // CHECK: %[[FILL:.*]] = dht.fill - // CHECK: dht.print_tensor %[[CREATE]], %[[FILL]] - %c0 = hex.new.chain - - // Create 2x2 dht with value 1 - %dht0 = dht.create_uninitialized_tensor.i32.2 [2 : i32, 2 : i32] - %c1 = dht.fill_tensor_with_constant.i32 %dht0, %c0 1 : i32 - - // Convert dht to tf tensor - %tft0, %c2 = "tfd.move_dht_to_tft"(%dht0, %c1) - : (!t.tensor, !hex.chain) -> (!tfd.tf_tensor, !hex.chain) - - // Convert tf tensor back to dht - %dht1, %c3 = "tfd.convert_tft_to_dht"(%tft0, %c2) - : (!tfd.tf_tensor, !hex.chain) -> (!t.tensor, !hex.chain) - - // Print the result dht - %c4 = dht.print_tensor %dht1, %c3 - - hex.return -} diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/attributes.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/attributes.mlir deleted file mode 100644 index 6c129c4be22..00000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/attributes.mlir +++ /dev/null @@ -1,21 +0,0 @@ -// RUN: tf-opt -tf-to-corert %s | FileCheck %s - -module attributes {tf_saved_model.semantics} { - -"tf_saved_model.global_tensor"() {is_mutable, sym_name = "y", type = tensor<1x3xf32>, value = dense<[[1.67482901, -0.529208779, -0.803792417]]> : tensor<1x3xf32>} : () -> () - -// CHECK-LABEL: func @basic -func @func_basic( - %arg0: tensor<3x1xf32> {tf_saved_model.index_path = [0]}, - %arg1: tensor>> {tf_saved_model.bound_input = @y}) - -> (tensor<3x3xf32> {tf_saved_model.index_path = []}) - attributes {tf_saved_model.exported_names = ["basic"]} { - %1 = "tf.ReadVariableOp"(%arg1) {_output_shapes = ["tfshape$dim { size: 1 } dim { size: 3 }"], device = "cpu", dtype = f32} : (tensor>>) -> tensor<1x3xf32> - - // CHECK: {{%.*}} = corert.executeop({{%.*}}) "tf.MatMul" - // CHECK-SAME: {T = f32, transpose_a = false, transpose_b = false} - %2 = "tf.MatMul"(%arg0, %1) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"], device = "cpu", transpose_a = false, transpose_b = false} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32> - return %2 : tensor<3x3xf32> -} - 
-} diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/basic.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/basic.mlir deleted file mode 100644 index 40b0332b61c..00000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/basic.mlir +++ /dev/null @@ -1,34 +0,0 @@ -// RUN: tf-opt -tf-to-corert %s | FileCheck %s - -// CHECK-NOT: tf_saved_model.semantics -module attributes {tf_saved_model.semantics} { - -// CHECK-NOT: "tf_saved_model.global_tensor" -"tf_saved_model.global_tensor"() {is_mutable, sym_name = "y", type = tensor<1x3xf32>, value = dense<[[1.67482901, -0.529208779, -0.803792417]]> : tensor<1x3xf32>} : () -> () -"tf_saved_model.global_tensor"() {is_mutable, sym_name = "z", type = tensor<3xf32>, value = dense<[1.67482901, -0.529208779, -0.803792417]> : tensor<3xf32>} : () -> () - -// CHECK-LABEL: func @basic -// CHECK-SAME: ([[arg0:%.*]]: !corert.tensorhandle, [[arg1:%.*]]: !corert.tensorhandle, -// CHECK-SAME: [[arg2:%.*]]: !corert.tensorhandle) -> !corert.tensorhandle { -func @func_basic( - %arg0: tensor<3x1xf32> {tf_saved_model.index_path = [0]}, - %arg1: tensor>> {tf_saved_model.bound_input = @y}, - %arg2: tensor>> {tf_saved_model.bound_input = @z}) - -> (tensor<3x3xf32> {tf_saved_model.index_path = []}) - attributes {tf_saved_model.exported_names = ["basic"]} { - // CHECK-NEXT: [[cpu_device:%.*]] = corert.get_device "cpu" - // CHECK-NEXT: [[r0:%.*]] = corert.executeop([[cpu_device]]) "tf.MatMul"([[arg0]], [[arg1]]) - // CHECK-NEXT: [[r1:%.*]] = corert.executeop([[cpu_device]]) "tf.BiasAdd"([[r0]], [[arg2]]) - // CHECK-NEXT: [[r2:%.*]] = corert.executeop([[cpu_device]]) "tf.Tanh"([[r1]]) - // CHECK-NEXT: hex.return [[r2]] : !corert.tensorhandle - - %0 = "tf.ReadVariableOp"(%arg2) {_output_shapes = ["tfshape$dim { size: 3 }"], device = "cpu", dtype = f32} : (tensor>>) -> tensor<3xf32> - %1 = "tf.ReadVariableOp"(%arg1) {_output_shapes = ["tfshape$dim { size: 1 } dim { size: 3 }"], device = "cpu", dtype = f32} : (tensor>>) -> tensor<1x3xf32> - %2 = "tf.MatMul"(%arg0, %1) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"], device = "cpu", transpose_a = false, transpose_b = false} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32> - %3 = "tf.BiasAdd"(%2, %0) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"], data_format = "NHWC", device = "cpu"} : (tensor<3x3xf32>, tensor<3xf32>) -> tensor<3x3xf32> - %4 = "tf.Tanh"(%3) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"], device = "cpu"} : (tensor<3x3xf32>) -> tensor<3x3xf32> - %5 = "tf.Identity"(%4) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"], device = "cpu"} : (tensor<3x3xf32>) -> tensor<3x3xf32> - return %5 : tensor<3x3xf32> -} - -} diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/derived_attrs.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/derived_attrs.mlir deleted file mode 100644 index 774ea0526bd..00000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/derived_attrs.mlir +++ /dev/null @@ -1,21 +0,0 @@ -// RUN: tf-opt -tf-to-corert %s | FileCheck %s - -// CHECK-LABEL: func @derived_attrs -func @derived_attrs( - %serialized: tensor, - %names: tensor<0x!tf.string>, - %sparse_keys: tensor<0x!tf.string>, - %dense_keys: tensor<1x!tf.string>, - %ragged_keys: tensor<0x!tf.string>, - %dense_default: tensor<0xi64>) -> tensor { - - %dense_value = - "tf.ParseExampleV2"(%serialized, %names, %sparse_keys, %dense_keys, %ragged_keys, %dense_default) - // CHECK: Tdense = [i64] - // CHECK-SAME: 
dense_shapes = [#corert.shape<>] - { device = "cpu", num_sparse = 0 : i64, dense_shapes = [#tf.shape<>], result_segment_sizes = dense<[0, 0, 0, 1, 0, 0]> : vector<6xi32>} - : (tensor, tensor<0x!tf.string>, tensor<0x!tf.string>, tensor<1x!tf.string>, tensor<0x!tf.string>, tensor<0xi64>) - -> tensor - - return %dense_value : tensor -} diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/device_conversion.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/device_conversion.mlir deleted file mode 100644 index 7077523b1e2..00000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/device_conversion.mlir +++ /dev/null @@ -1,12 +0,0 @@ -// RUN: tf-opt -tf-to-corert %s | FileCheck %s - -// CHECK-LABEL: func @device_test -func @device_test( - %arg0: tensor<3x1xf32> {tf_saved_model.index_path = [0]}, - %arg1: tensor<1x3xf32> {tf_saved_model.index_path = [0]}) - -> (tensor<3x3xf32> {tf_saved_model.index_path = []}) { - // CHECK: {{%.*}} = corert.get_device "gpu" - - %2 = "tf.MatMul"(%arg0, %arg1) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"], device = "gpu", transpose_a = false, transpose_b = false} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32> - return %2 : tensor<3x3xf32> -} diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/fold.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/fold.mlir deleted file mode 100644 index 950cef928a9..00000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/fold.mlir +++ /dev/null @@ -1,12 +0,0 @@ -// RUN: tf-opt -corert-optimize %s | FileCheck %s - -// CHECK-LABEL: func @fold_test -func @fold_test(%arg: !corert.tensorhandle) -> !corert.tensorhandle { - %cpu = corert.get_device "cpu" - // CHECK-NOT: tf.Const - %0 = corert.executeop(%cpu) "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : 1 - // CHECK: "_tf.Transpose"({{%.*}}) - // CHECK-SAME: perm = dense<[0, 3, 1, 2]> : tensor<4xi32> - %1 = corert.executeop(%cpu) "tf.Transpose"(%arg, %0) {T = f32, Tperm = i32} : 1 - hex.return %1 : !corert.tensorhandle -} diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/string_tensor.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/string_tensor.mlir deleted file mode 100644 index b1306be825c..00000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/string_tensor.mlir +++ /dev/null @@ -1,10 +0,0 @@ -// RUN: tf-opt -tf-to-corert %s | FileCheck %s - -// CHECK-LABEL: func @string_tensor -func @string_tensor() -> (tensor<0x!tf.string>, tensor<7x!tf.string>) { - // CHECK: {shape = [0], value = []} - %0 = "tf.Const"() {value = dense<[]> : tensor<0x!tf.string>} : () -> tensor<0x!tf.string> - // CHECK: {shape = [7], value = ["has_login_page_feature", "num_terms_inside_postform", "num_terms_outside_postform", "num_terms_outside_postform_without_bp", "query_params_contains_url", "title_with_login_phase", "url_contains_login_terms"]} - %1 = "tf.Const"() {value = dense<["has_login_page_feature", "num_terms_inside_postform", "num_terms_outside_postform", "num_terms_outside_postform_without_bp", "query_params_contains_url", "title_with_login_phase", "url_contains_login_terms"]> : tensor<7x!tf.string>} : () -> tensor<7x!tf.string> - return %0, %1 : tensor<0x!tf.string>, tensor<7x!tf.string> -} diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/tf_executor_to_corert_pipeline.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/tf_executor_to_corert_pipeline.mlir deleted file mode 100644 index 5c44f558280..00000000000 --- 
a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/tf_executor_to_corert_pipeline.mlir +++ /dev/null @@ -1,24 +0,0 @@ -// RUN: tf-opt -tf-executor-to-corert-pipeline %s | FileCheck %s - -// CHECK-LABEL: func @basic -// CHECK-SAME: ([[arg0:%.*]]: !corert.tensorhandle, [[arg1:%.*]]: !corert.tensorhandle) -// CHECK-NEXT: [[cpu:%.*]] = corert.get_device "cpu" -// CHECK-NEXT: [[res:%.*]] = corert.executeop([[cpu]]) "tf.MatMul"([[arg0]], [[arg1]]) -// CHECK-NEXT: hex.return [[res]] : !corert.tensorhandle -module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 293 : i32}} { - func @basic(%arg0: tensor<3x1xf32>, - %arg1: tensor>> - ) -> tensor<3x3xf32> { - %0 = tf_executor.graph { - %outputs, %control = tf_executor.island wraps "tf.Const"() {value = dense<0.899999976> : tensor} : () -> tensor - %outputs_0, %control_0 = tf_executor.island { - %1 = "tf.Cast"(%arg1) {Truncate = false} : (tensor>>) -> tensor<*x!tf.resource> - %2 = "tf.ReadVariableOp"(%1) {_output_shapes = ["tfshape$dim { size: 1 } dim { size: 3 }"], device = "", dtype = f32} : (tensor<*x!tf.resource>) -> tensor<1x3xf32> - %3 = "tf.MatMul"(%arg0, %2) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"], device = "", transpose_a = false, transpose_b = false} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32> - tf_executor.yield %3 : tensor<3x3xf32> - } - tf_executor.fetch %outputs_0, %control_0 : tensor<3x3xf32>, !tf_executor.control - } - return %0 : tensor<3x3xf32> - } -} diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_tfd_lowering.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_tfd_lowering.mlir deleted file mode 100644 index 5968a590f91..00000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_tfd_lowering.mlir +++ /dev/null @@ -1,111 +0,0 @@ -// RUN: tf-opt %s -tf-to-tfd-lowering | FileCheck %s - -// CHECK: func @inference_call( -// CHECK-SAME: %arg0: !hex.chain, -// CHECK-SAME: %arg1: !tfd.tf_tensor, -// CHECK-SAME: %arg2: !tfd.tf_tensor, -// CHECK-SAME: %arg3: !tfd.tf_tensor, -// CHECK-SAME: %arg4: !tfd.tf_tensor, -// CHECK-SAME: %arg5: !tfd.tf_tensor -// CHECK-SAME: ) -> (!hex.chain, !tfd.tf_tensor) -func @inference_call( - %arg0: tensor, - %arg1: tensor<*x!tf.resource>, - %arg2: tensor<*x!tf.resource>, - %arg3: tensor<*x!tf.resource>, - %arg4: tensor<*x!tf.resource> - )-> tensor { - // CHECK: %0:2 = "tfd.delegate_kernel"(%arg0, %arg5) - // CHECK-SAME: _name = "tf.ReadVariableOp" - // CHECK-SAME: attr0_name = "dtype", attr0_value = "tfdtype$DT_FLOAT" - // CHECK-SAME: (!hex.chain, !tfd.tf_tensor) -> (!hex.chain, !tfd.tf_tensor) - %0 = "tf.ReadVariableOp"(%arg4) { - dtype = "tfdtype$DT_FLOAT" - } : (tensor<*x!tf.resource>) -> tensor<10xf32> - - // CHECK: %1:2 = "tfd.delegate_kernel"(%0#0, %arg3) { - // CHECK-SAME: _name = "tf.ReadVariableOp" - // CHECK-SAME: attr0_name = "dtype", attr0_value = "tfdtype$DT_FLOAT" - // CHECK-SAME: } : (!hex.chain, !tfd.tf_tensor) - // CHECK-SAME: -> (!hex.chain, !tfd.tf_tensor) - %1 = "tf.ReadVariableOp"(%arg2) { - dtype = "tfdtype$DT_FLOAT" - } : (tensor<*x!tf.resource>) -> tensor<512xf32> - - // CHECK: %2:2 = "tfd.delegate_kernel"(%1#0, %arg4) { - // CHECK-SAME: _name = "tf.ReadVariableOp", - // CHECK-SAME: attr0_name = "dtype", attr0_value = "tfdtype$DT_FLOAT" - // CHECK-SAME: } : (!hex.chain, !tfd.tf_tensor) - // CHECK-SAME: -> (!hex.chain, !tfd.tf_tensor) - %2 = "tf.ReadVariableOp"(%arg3) { - dtype = "tfdtype$DT_FLOAT" - } : (tensor<*x!tf.resource>) -> tensor<512x10xf32> - - // CHECK: %3:2 = "tfd.delegate_kernel"(%2#0, %arg2) 
{ - // CHECK-SAME: _name = "tf.ReadVariableOp", - // CHECK-SAME: attr0_name = "dtype", attr0_value = "tfdtype$DT_FLOAT" - // CHECK-SAME: } : (!hex.chain, !tfd.tf_tensor) - // CHECK-SAME: -> (!hex.chain, !tfd.tf_tensor) - %3 = "tf.ReadVariableOp"(%arg1) { - dtype = "tfdtype$DT_FLOAT" - } : (tensor<*x!tf.resource>) -> tensor<784x512xf32> - - // CHECK: %4:2 = "tfd.delegate_kernel"(%3#0, %arg1, %3#1) { - // CHECK-SAME: _name = "tf.MatMul", - // CHECK-SAME: attr0_name = "dtype", attr0_value = "tfdtype$DT_FLOAT", - // CHECK-SAME: attr1_name = "transpose_a", attr1_value = false, - // CHECK-SAME: attr2_name = "transpose_b", attr2_value = false - // CHECK-SAME: } : (!hex.chain, !tfd.tf_tensor, !tfd.tf_tensor) - // CHECK-SAME: -> (!hex.chain, !tfd.tf_tensor) - %4 = "tf.MatMul"(%arg0, %3) { - dtype = "tfdtype$DT_FLOAT", transpose_a = false, transpose_b = false - } : (tensor, tensor<784x512xf32>) -> tensor - - // CHECK: %5:2 = "tfd.delegate_kernel"(%4#0, %4#1, %1#1) { - // CHECK-SAME: _name = "tf.AddV2" - // CHECK-SAME: } : (!hex.chain, !tfd.tf_tensor, !tfd.tf_tensor) - // CHECK-SAME: -> (!hex.chain, !tfd.tf_tensor) - %5 = "tf.AddV2"(%4, %1) - : (tensor, tensor<512xf32>)-> tensor - - // CHECK: %6:2 = "tfd.delegate_kernel"(%5#0, %5#1) { - // CHECK-SAME: _name = "tf.Relu", - // CHECK-SAME: attr0_name = "dtype", attr0_value = "tfdtype$DT_FLOAT" - // CHECK-SAME: } : (!hex.chain, !tfd.tf_tensor) - // CHECK-SAME: -> (!hex.chain, !tfd.tf_tensor) - %6 = "tf.Relu"(%5) { - dtype = "tfdtype$DT_FLOAT" - } : (tensor) -> tensor - - // CHECK: %7:2 = "tfd.delegate_kernel"(%6#0, %6#1, %2#1) { - // CHECK-SAME: _name = "tf.MatMul", - // CHECK-SAME: attr0_name = "dtype", attr0_value = "tfdtype$DT_FLOAT", - // CHECK-SAME: attr1_name = "transpose_a", attr1_value = false, - // CHECK-SAME: attr2_name = "transpose_b", attr2_value = false - // CHECK-SAME: } : (!hex.chain, !tfd.tf_tensor, !tfd.tf_tensor) - // CHECK-SAME: -> (!hex.chain, !tfd.tf_tensor) - %7 = "tf.MatMul"(%6, %2) { - dtype = "tfdtype$DT_FLOAT", transpose_a = false, transpose_b = false - } : (tensor, tensor<512x10xf32>) -> tensor - - // CHECK: %8:2 = "tfd.delegate_kernel"(%7#0, %7#1, %0#1) { - // CHECK-SAME: _name = "tf.AddV2", - // CHECK-SAME: attr0_name = "dtype", attr0_value = "tfdtype$DT_FLOAT" - // CHECK-SAME: } : (!hex.chain, !tfd.tf_tensor, !tfd.tf_tensor) - // CHECK-SAME: -> (!hex.chain, !tfd.tf_tensor) - %8 = "tf.AddV2"(%7, %0) { - dtype = "tfdtype$DT_FLOAT" - } : (tensor, tensor<10xf32>) -> tensor - - // CHECK: %9:2 = "tfd.delegate_kernel"(%8#0, %8#1) { - // CHECK-SAME: _name = "tf.Identity", - // CHECK-SAME: attr0_name = "dtype", attr0_value = "tfdtype$DT_FLOAT" - // CHECK-SAME: } : (!hex.chain, !tfd.tf_tensor) - // CHECK-SAME: -> (!hex.chain, !tfd.tf_tensor) - %9 = "tf.Identity"(%8) { - dtype = "tfdtype$DT_FLOAT" - } : (tensor) -> tensor - - // CHECK: hex.return %9#0, %9#1 : !hex.chain, !tfd.tf_tensor - return %9 : tensor -} diff --git a/tensorflow/compiler/mlir/tfrt/tf_legalize_to_hex.cc b/tensorflow/compiler/mlir/tfrt/tf_legalize_to_hex.cc deleted file mode 100644 index 9d13955490b..00000000000 --- a/tensorflow/compiler/mlir/tfrt/tf_legalize_to_hex.cc +++ /dev/null @@ -1,163 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file implements lowering of Tf dialect to TFRT Hex kernels. -// -// Current lowering is a placeholder performing trivial conversion -// for integer constants and additions. - -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Module.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" -#include "absl/memory/memory.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" - -namespace mlir { -namespace { - -// Pattern rewrite rules for "tf.Const", "tf.Add" and "return" ops. -bool isInt32LikeType(Type t) { - if (t.isSignlessInteger(32)) return true; - if (auto ttype = t.dyn_cast()) { - if (ttype.hasStaticShape() && ttype.getNumElements() == 1 && - ttype.getElementType().isSignlessInteger(32)) - return true; - } - return false; -} - -// Replaces 32-bit integer TF::ConstOp with "hex.constant_int" op. -struct ConstOpConversion : public ConversionPattern { - explicit ConstOpConversion(MLIRContext *context) - : ConversionPattern(TF::ConstOp::getOperationName(), 1, context) {} - - LogicalResult matchAndRewrite( - Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto constOp = cast(op); - if (!isInt32LikeType(constOp.getType())) return failure(); - - auto valueAttr = constOp.value(); - auto newAttr = Attribute(); - - // Convert constant op if it has an integer or dense elements attribute. - // Other kinds of element attributes are not converted for now. - if (valueAttr.isa()) { - newAttr = valueAttr; - } else if (auto v = valueAttr.dyn_cast()) { - if (v.isSplat()) newAttr = v.getSplatValue(); - } - if (!newAttr) return failure(); - - mlir::OperationState state(constOp.getLoc(), "hex.constant_int"); - state.types.push_back(rewriter.getIntegerType(32)); - state.addAttribute("value", newAttr); - auto newOp = rewriter.createOperation(state); - rewriter.replaceOp(op, newOp->getResult(0)); - return success(); - } -}; - -// Replaces 32-bit integer TF::Add op with "hex.add_int" op. -struct AddOpConversion : public ConversionPattern { - explicit AddOpConversion(MLIRContext *context) - : ConversionPattern(TF::AddOp::getOperationName(), 1, context) {} - - LogicalResult matchAndRewrite( - Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto addOp = cast(op); - - if (!isInt32LikeType(operands[0].getType()) || - !isInt32LikeType(operands[1].getType())) - return failure(); - - auto int32Ty = rewriter.getIntegerType(32); - mlir::OperationState state(addOp.getLoc(), "hex.add_int", operands, - {int32Ty}, {}); - auto newOp = rewriter.createOperation(state); - rewriter.replaceOp(op, newOp->getResult(0)); - return success(); - } -}; - -// Replaces return op that has no arguments with "hex.return" op. 
-struct ReturnOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - ReturnOp srcOp, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - if (srcOp.getNumOperands() != 0) return failure(); - - mlir::OperationState state(srcOp.getLoc(), "hex.return"); - rewriter.createOperation(state); - - rewriter.eraseOp(srcOp); - return success(); - } -}; - -// Legalize TF operations to host program dialect. -struct TfLegalizeToHex - : public PassWrapper> { - void runOnOperation() override { - auto *ctx = &getContext(); - TypeConverter converter; - converter.addConversion([](Type type) -> Type { - // Convert single element tensor type of int32s to int32 type - if (isInt32LikeType(type)) { - return IntegerType::get(32, type.getContext()); - } - return Type(); - }); - - OwningRewritePatternList patterns; - - // For now, replace only int32 TF::OpConst, TF::OpAdd and OpReturn with - // "hex.constant_int", "hex.add_int" and "hex.return", respectively. - patterns.insert( - ctx); - - ConversionTarget target(*ctx); - const auto legal = ConversionTarget::LegalizationAction::Legal; - target.setOpAction(OperationName(StringRef("hex.constant_int"), ctx), - legal); - target.setOpAction(OperationName(StringRef("hex.add_int"), ctx), legal); - target.setOpAction(OperationName(StringRef("hex.return"), ctx), legal); - target.addLegalOp(); - - auto result = - applyFullConversion(getOperation(), target, patterns, &converter); - if (failed(result)) signalPassFailure(); - } -}; - -} // namespace - -std::unique_ptr> createLegalizeToHexPass() { - return std::make_unique(); -} - -static PassRegistration pass( - "tf-legalize-to-hex", - "Convert TF dialect to the TF runtime host program dialect."); -} // namespace mlir diff --git a/tensorflow/compiler/mlir/tfrt/transforms/optimize.cc b/tensorflow/compiler/mlir/tfrt/transforms/optimize.cc deleted file mode 100644 index 9e06ba1f4bc..00000000000 --- a/tensorflow/compiler/mlir/tfrt/transforms/optimize.cc +++ /dev/null @@ -1,122 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file implements the optimzation passe on TFRT CoreRuntime dialect. -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/PatternMatch.h" -#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" -#include "tfrt/core_runtime/opdefs/core_runtime.h" - -namespace tensorflow { -namespace { - -// Implement a constant fold pattern for corert dialect. 
The following pattern -// will be matched: -// -// %0 = corert.executeop(%cpu) "tf.Const"() -// {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : 1 -// %1 = corert.executeop(%cpu) "tf.Transpose"(%arg, %0) -// {T = f32, Tperm = i32} : 1 -// -// And it will converted to: -// -// %1 = corert.executeop(%cpu) "_tf.Transpose"(%arg) -// {T = f32, Tperm = i32, perm = dense<[0, 3, 1, 2]> : tensor<4xi32>} : 1 -// -class CoreRTExecuteOpRewritePattern - : public mlir::OpRewritePattern { - public: - CoreRTExecuteOpRewritePattern( - mlir::MLIRContext *context, - ArrayRef>> ops_to_attrs) - : OpRewritePattern(context), - ops_to_attrs_(ops_to_attrs.begin(), ops_to_attrs.end()) {} - - mlir::LogicalResult matchAndRewrite( - tfrt::corert::ExecuteOp op, - mlir::PatternRewriter &rewriter) const override { - auto attr_names = ops_to_attrs_.lookup(op.op_name()); - if (attr_names.empty()) return failure(); - - SmallVector new_operands; - SmallVector, 4> new_attributes; - op.getOpAttrs(&new_attributes); - assert(op.operands().size() == attr_names.size()); - for (const auto &iter : llvm::zip(op.operands(), attr_names)) { - mlir::Value arg = std::get<0>(iter); - StringRef name = std::get<1>(iter); - - Attribute const_attr; - if (!name.empty() && matchPattern(arg, m_Constant(&const_attr))) { - // Convert the folded argument to an attribute. - new_attributes.push_back({name, const_attr}); - } else { - // Keep the argument that is not folded. - new_operands.push_back(arg); - } - } - - if (new_operands.size() == op.operands().size()) return failure(); - - SmallString<32> new_op_name{"_"}; - new_op_name += op.op_name(); - - rewriter.replaceOpWithNewOp( - op, op.getResultTypes(), op.device(), new_operands, new_attributes, - new_op_name); - - return success(); - } - - private: - // Map from op_name to attr_names. The attr_names indicates the name of the - // attribute to which each constant-folded argument is converted. An empty - // string means this argument should not be folded. - llvm::DenseMap> ops_to_attrs_; -}; - -struct CoreRTOptimizePass - : public mlir::PassWrapper { - void runOnFunction() override { - mlir::OwningRewritePatternList patterns; - auto func = getFunction(); - - static constexpr StringRef kMeanAttrs[] = {"", "reduction_indices"}; - static constexpr StringRef kPadAttrs[] = {"", "paddings"}; - static constexpr StringRef kTransposeAttrs[] = {"", "perm"}; - - static constexpr std::pair> kOpsToAttrs[] = { - {"tf.Mean", kMeanAttrs}, - {"tf.Pad", kPadAttrs}, - {"tf.Transpose", kTransposeAttrs}, - }; - - patterns.insert(&getContext(), kOpsToAttrs); - - mlir::applyPatternsAndFoldGreedily(func, patterns); - } -}; - -} // namespace - -std::unique_ptr> CreateCoreRTOptimizePass() { - return std::make_unique(); -} - -static mlir::PassRegistration pass("corert-optimize", - "Optimizes corert."); - -} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/passes.h b/tensorflow/compiler/mlir/tfrt/transforms/passes.h deleted file mode 100644 index be0bf0fbd1f..00000000000 --- a/tensorflow/compiler/mlir/tfrt/transforms/passes.h +++ /dev/null @@ -1,64 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_PASSES_H_ -#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_PASSES_H_ - -#include - -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project - -namespace tensorflow { - -// Create a pass that converts MLIR TF dialect to MLIR TFRT CoreRT dialect. -std::unique_ptr CreateTFToCoreRTConversionPass(); - -// Run TFToCoreRTConversionPass as a free function. Useful for reusing the pass -// logic in a custom pass with additional conversions. -mlir::LogicalResult TFToCoreRTConversionPassRun( - mlir::MLIRContext* context, mlir::ModuleOp* module, - mlir::ConversionTarget* target, mlir::OwningRewritePatternList* patterns); - -// Create the corert optimization pass. -std::unique_ptr> CreateCoreRTOptimizePass(); - -struct CoreRTPipelineOptions - : public mlir::PassPipelineOptions { - Option default_device{ - *this, "default-device", llvm::cl::desc("default device assignment"), - llvm::cl::init("cpu")}; - Option enable_optimizer{ - *this, "enable-optimizer", - llvm::cl::desc("run optimization passes on corert dialect"), - llvm::cl::init(false)}; - Option force_data_format{ - *this, "force-data-format", - llvm::cl::desc("force data format for all layout sensitive operations")}; -}; - -// Creates a pipeline of passes that lowers MLIR TF Executor dialect to TF -// dialect for CoreRT purposes. -void CreateTFExecutorToTFPipeline( - mlir::OpPassManager& pm, const CoreRTPipelineOptions& options); // NOLINT - -// Creates a pipeline of passes that converts MLIR TF Executor dialect to CoreRT -// dialect. -void CreateTFExecutorToCoreRTPipeline( - mlir::OpPassManager& pm, const CoreRTPipelineOptions& options); // NOLINT - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_PASSES_H_ diff --git a/tensorflow/compiler/mlir/tfrt/transforms/tf_to_corert.cc b/tensorflow/compiler/mlir/tfrt/transforms/tf_to_corert.cc deleted file mode 100644 index 0784dc4ffea..00000000000 --- a/tensorflow/compiler/mlir/tfrt/transforms/tf_to_corert.cc +++ /dev/null @@ -1,484 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file implements lowering of TF dialect to TFRT CoreRuntime ExecuteOp. -// This lowering pass is heavily experimental and incomplete. External code -// should not depend on the code here. And please do not take example on it as -// "the path forward" for this. 
- -#include - -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/Passes.h" -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/OperationSupport.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project -#include "mlir/Pass/PassOptions.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" -#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/platform/tstring.h" -#include "tfrt/basic_kernels/opdefs/basic_kernels.h" -#include "tfrt/core_runtime/opdefs/attributes.h" -#include "tfrt/core_runtime/opdefs/core_runtime.h" - -namespace tensorflow { -namespace { - -// TODO(chky): define these dialect types instead of using opaque types. -mlir::Type CreateDeviceType(mlir::Builder *builder) { - return mlir::OpaqueType::get(builder->getIdentifier("corert"), "device", - builder->getContext()); -} - -mlir::Type CreateTensorHandleType(mlir::Builder *builder) { - return mlir::OpaqueType::get(builder->getIdentifier("corert"), "tensorhandle", - builder->getContext()); -} - -mlir::Type CreateStringType(mlir::Builder *builder) { - return mlir::OpaqueType::get(builder->getIdentifier("hex"), "string", - builder->getContext()); -} - -// A helper class for converting CoreRT types and attributes. -class CoreRTConverter : public mlir::TypeConverter { - public: - explicit CoreRTConverter(mlir::MLIRContext *context) - : builder_(context), - device_type_(CreateDeviceType(&builder_)), - tensor_handle_type_(CreateTensorHandleType(&builder_)) { - addConversion([](Type type) { return type; }); - addConversion([=](TensorType type) { return tensor_handle_type_; }); - } - - // Create a single attribute that contains the named attribute lists. It is an - // array of pairs. The key must be a string attribute, and the value can be - // any attribute that is supported by CoreRuntime. - mlir::ArrayAttr CreateOpAttrs(ArrayRef attrs) { - llvm::SmallVector attr_array; - for (auto key_and_value : attrs) { - if (!IsUnusedAttribute(key_and_value.first)) { - auto converted = ConvertAttribute(key_and_value.second); - if (!converted) return {}; - - mlir::StringAttr key = builder_.getStringAttr(key_and_value.first); - attr_array.push_back(builder_.getArrayAttr({key, converted})); - } - } - return builder_.getArrayAttr(attr_array); - } - - // Convert the device attribute in `op` to a device value produced by the - // corresponding GetDeviceOp in the current block. If there does not exist - // one, insert a GetDeviceOp to the beginning of the block and return the - // device value. 
- Value ConvertDevice(mlir::Operation *op, - ConversionPatternRewriter *rewriter) const { - auto device_attr = op->getAttr("device"); - if (!device_attr) { - op->emitOpError("device attribute not found."); - return {}; - } - - auto device_name = device_attr.cast().getValue(); - if (device_name.empty()) { - op->emitOpError("device has not been assigned."); - return {}; - } - - op->removeAttr(rewriter->getIdentifier("device")); - - auto *block = op->getBlock(); - - if (auto get_device_op = GetDeviceOrNull(device_name, block)) - return get_device_op.device(); - - ConversionPatternRewriter::InsertionGuard insertion_guard(*rewriter); - rewriter->setInsertionPointToStart(block); - return rewriter - ->create(block->getParent()->getLoc(), - device_type(), device_name) - .device(); - } - - mlir::Type device_type() const { return device_type_; } - mlir::Type tensor_handle_type() const { return tensor_handle_type_; } - - private: - // TODO(chky): attributes "_output_shapes" should be removed by any tool that - // generates TF MLIR dialect, as they are not used by CoreRuntime. Remove this - // filtering logic once unused attributes are cleaned up in the upper layer. - bool IsUnusedAttribute(llvm::StringRef name) const { - return name == "_output_shapes"; - } - - // Returns the converted attribute in TFRT dialect. If the conversion fails, - // returns a null attribute instead. - mlir::Attribute ConvertAttribute(mlir::Attribute attr) { - // The supported attributes here should be kept consistent with - // //third_party/tf_runtime/include/tfrt/core_runtime/op_attr_type.h - // - // Currently, not all tensorflow data types are supported. Unranked shape - // attributes are not supported yet. - - // Return directly if the attribute is already supported. - if (attr.isa() || attr.isa() || - attr.isa() || attr.isa() || - attr.isa() || - attr.isa()) - return attr; - - // Convert the attribute to the corresponding format in TFRT dialect if - // needed. - if (auto shape_attr = attr.dyn_cast()) { - if (!shape_attr.hasRank()) return {}; - return tfrt::corert::ShapeAttr::get(builder_.getContext(), - shape_attr.getShape()); - } - - // For arrays, we recursively convert the elements. - if (auto array_attr = attr.dyn_cast()) { - llvm::SmallVector attrs; - attrs.reserve(array_attr.size()); - for (auto attr : array_attr) { - auto converted = ConvertAttribute(attr); - if (!converted) return {}; - attrs.push_back(converted); - } - return builder_.getArrayAttr(attrs); - } - - return {}; - } - - // Find a GetDeviceOp that matches the device_name at the beginning of the - // block. Return nullptr if it does not find one. - tfrt::corert::GetDeviceOp GetDeviceOrNull(StringRef device_name, - Block *block) const { - for (auto &op : *block) { - auto get_device_op = llvm::dyn_cast(&op); - if (!get_device_op) break; - if (get_device_op.device_name() == device_name) return get_device_op; - } - return nullptr; - } - - mlir::Builder builder_; - mlir::Type device_type_; - mlir::Type tensor_handle_type_; -}; - -// Lower a tf.Const op that creates a string tensor to a native -// corert.create_string_tensor op. 
-class CoreRTConstStringTensorOpConversion - : public mlir::OpConversionPattern { - public: - CoreRTConstStringTensorOpConversion(mlir::MLIRContext *context, - CoreRTConverter *corert_converter) - : mlir::OpConversionPattern(context), - corert_converter_(*corert_converter) {} - - LogicalResult matchAndRewrite( - mlir::TF::ConstOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { // NOLINT - if (!op.dtype().isa()) return failure(); - - DenseStringElementsAttr attr = op.value().cast(); - - llvm::SmallVector values; - values.reserve(attr.getNumElements()); - for (const auto &element : attr.getRawStringData()) - values.push_back(rewriter.getStringAttr( - llvm::StringRef(element.data(), element.size()))); - - // Create the shape attribute from the tensor shape. - ArrayRef shape = op.value().getType().getShape(); - llvm::SmallVector dims; - dims.reserve(shape.size()); - auto i64_type = rewriter.getIntegerType(64); - for (auto dim : shape) - dims.push_back(rewriter.getIntegerAttr(i64_type, dim)); - - auto new_op = rewriter.create( - op.getLoc(), corert_converter_.tensor_handle_type(), - rewriter.getArrayAttr(dims), rewriter.getArrayAttr(values)); - - rewriter.replaceOp(op, new_op.result()); - - return success(); - } - - private: - CoreRTConverter &corert_converter_; -}; - -// Convert TF dialect operations with no side effects to CoreRT ExecuteOp. For -// example, -// -// %0 = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : -// (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32> -// -// is converted to -// -// %result = corert.executeop(%device) -// "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : -// (!corert.tensorhandle, !corert.tensorhandle) -> !corert.tensorhandle -// -// Note that it will fail to match if some attributes are not supported. -template -class CoreRTExecuteOpConversion : public mlir::OpConversionPattern { - public: - CoreRTExecuteOpConversion(mlir::MLIRContext *context, - CoreRTConverter *corert_converter) - : mlir::OpConversionPattern(context), - corert_converter_(*corert_converter) {} - - LogicalResult matchAndRewrite( - TF_Op op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { // NOLINT - mlir::StringAttr op_name = rewriter.getStringAttr(op.getOperationName()); - - llvm::SmallVector result_types; - for (auto type : op.getOperation()->getResultTypes()) - result_types.push_back(corert_converter_.convertType(type)); - - // Get the device, or create one if there does not exist one. - auto device = corert_converter_.ConvertDevice(op, &rewriter); - if (!device) return failure(); - - auto derived_attrs = op.materializeDerivedAttributes(); - for (auto named_attr : derived_attrs) { - op.setAttr(named_attr.first, named_attr.second); - } - - ArrayAttr op_attrs = corert_converter_.CreateOpAttrs(op.getAttrs()); - if (!op_attrs) return failure(); - - auto new_op = rewriter.create( - op.getLoc(), result_types, device, operands, op_attrs, op_name); - - rewriter.replaceOp(op, new_op.results()); - return success(); - } - - private: - CoreRTConverter &corert_converter_; -}; - -// Deletes the op and forwards the arguments. -template -class PassThroughConversion : public mlir::OpConversionPattern { - public: - explicit PassThroughConversion(MLIRContext *context) - : mlir::OpConversionPattern(context) {} - - LogicalResult matchAndRewrite( - TF_Op op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { // NOLINT - // Just forward the arguments to results. 
- rewriter.replaceOp(op, operands); - return success(); - } -}; - -// Convert standard ReturnOp to hex.return. -// -// TODO(chky): conversion to hex kernels should come from a common tf_to_hex -// library. -class ReturnOpConversion : public mlir::OpConversionPattern { - public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - mlir::ReturnOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, operands); - return success(); - } -}; - -// Convert TF dialect to CoreRT dialect. -class TFToCoreRTConversionPass - : public mlir::PassWrapper> { - void runOnOperation() override { - auto module = getOperation(); - mlir::ConversionTarget target(getContext()); - mlir::OwningRewritePatternList patterns; - if (failed(TFToCoreRTConversionPassRun(&getContext(), &module, &target, - &patterns))) - signalPassFailure(); - } -}; - -} // namespace - -LogicalResult TFToCoreRTConversionPassRun( - mlir::MLIRContext *context, mlir::ModuleOp *module, - mlir::ConversionTarget *target, mlir::OwningRewritePatternList *patterns) { - module->removeAttr("tf_saved_model.semantics"); - - mlir::Builder builder(context); - auto bound_id = builder.getIdentifier("tf_saved_model.bound_input"); - auto path_id = builder.getIdentifier("tf_saved_model.index_path"); - - module->walk([bound_id, path_id, module](mlir::Operation *op) mutable { - if (auto func_op = dyn_cast(op)) { - // Remove tf_saved_model specific function arg attributes. - for (unsigned i = 0, e = func_op.getNumArguments(); i != e; ++i) { - func_op.removeArgAttr(i, bound_id); - func_op.removeArgAttr(i, path_id); - } - for (unsigned i = 0, e = func_op.getNumResults(); i != e; ++i) { - func_op.removeResultAttr(i, bound_id); - func_op.removeResultAttr(i, path_id); - } - if (auto exported_names = func_op.getAttrOfType( - "tf_saved_model.exported_names")) { - // Create a function for each exported name. - // - // TODO(b/148477882): TFRT dialect should have similar concepts of - // exported names so that a function can be referenced by multiple - // exported names. - func_op.removeAttr("tf_saved_model.exported_names"); - for (auto exported_name : exported_names) { - auto exported_func_op = func_op.clone(); - exported_func_op.setName( - exported_name.cast().getValue()); - module->insert(module->begin(), exported_func_op); - } - func_op.erase(); - } - } else if (isa(op)) { - // Remove all global_tensor_ops. - op->erase(); - } - }); - - CoreRTConverter corert_converter(context); - - target->addLegalDialect(); - target->addLegalDialect(); - target->addIllegalDialect(); - target->addDynamicallyLegalOp([&corert_converter](FuncOp op) { - return corert_converter.isSignatureLegal(op.getType()); - }); - - patterns->insert, - PassThroughConversion, ReturnOpConversion>( - context); - - // Here we use one specialized pattern for tf.Const with string tensors as - // it will incorrect to use ExecuteOp pattern to convert string tensor - // attribute. - patterns->insert(context, - &corert_converter); - - // TODO(b/148823030): Pattern registration for TF operations is not - // sustainable currently. We need to figure out a plan - patterns->insert, - // TODO(chky): Move the ReadVariableOp + Identity pattern - // to optimizer. 
- // CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion>(context, - &corert_converter); - - mlir::populateFuncOpTypeConversionPattern(*patterns, context, - corert_converter); - return mlir::applyPartialConversion(*module, *target, *patterns); -} - -std::unique_ptr CreateTFToCoreRTConversionPass() { - return std::make_unique(); -} - -void CreateTFExecutorToTFPipeline(mlir::OpPassManager &pm, - const CoreRTPipelineOptions &options) { - // First, we prune unused operations in MLIR in TF Executor dialect. - pm.addPass(mlir::tf_executor::CreateTFExecutorGraphPruningPass()); - - // Then we pass the MLIR module through the TF standard pipeline, which for - // instances does shape inference, canonicalization, inlining, etc. - mlir::TF::StandardPipelineOptions tf_options; - tf_options.enable_inliner = true; - mlir::TF::CreateTFStandardPipeline(pm, tf_options); - - // After all standard passes run layout optimization to assign optimal data - // format for all layout sensitive operations. - mlir::TF::LayoutOptimizationPipelineOptions layout_optimization_options; - layout_optimization_options.force_data_format = - options.force_data_format.getValue(); - mlir::TF::CreateLayoutOptimizationPipeline(pm, layout_optimization_options); - - // Run canonicalization pipeline to remove unused constants and bypassed - // transpose operations left in the IR after layout optimization. - pm.addNestedPass(mlir::createCanonicalizerPass()); - - if (options.default_device == "gpu") - pm.addNestedPass(mlir::TF::CreateGpuOpFusionPass()); - - // Then we assign default devices. - pm.addNestedPass( - mlir::TF::CreateSimpleTFDeviceAssignmentPass(options.default_device)); -} - -void CreateTFExecutorToCoreRTPipeline(mlir::OpPassManager &pm, - const CoreRTPipelineOptions &options) { - CreateTFExecutorToTFPipeline(pm, options); - - // Convert it to MLIR in CoreRT dialect. - pm.addPass(CreateTFToCoreRTConversionPass()); - - // Run optimizer on the MLIR module in CoreRT dialect. 
- if (options.enable_optimizer) - pm.addNestedPass(CreateCoreRTOptimizePass()); -} - -static mlir::PassRegistration pass( - "tf-to-corert", - "Convert Tensorflow dialect to TFRT's CoreRuntime dialect."); - -static mlir::PassPipelineRegistration pipeline( - "tf-executor-to-corert-pipeline", - "Convert Tensorflow Executor dialect to TFRT's CoreRuntime dialect, and " - "also apply necessary optimization passes.", - CreateTFExecutorToCoreRTPipeline); - -} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD new file mode 100644 index 00000000000..27a8dbd2809 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD @@ -0,0 +1,50 @@ +load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") + +licenses(["notice"]) + +cc_library( + name = "cubin_creator", + srcs = ["cubin_creator.cc"], + hdrs = ["cubin_creator.h"], + copts = if_cuda(["-DGOOGLE_CUDA=1"]), + deps = [ + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:support", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:TargetNVVMIR", + "@llvm-project//mlir:Transforms", + "//tensorflow/compiler/mlir/xla:hlo", + "//tensorflow/compiler/mlir/xla:lhlo", + "//tensorflow/compiler/mlir/xla:xla_legalize_tf", + "//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts", # buildcleaner: keep + "//tensorflow/compiler/mlir/xla:xla_unfuse_batch_norm", # buildcleaner: keep + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service/gpu:stream_executor_util", + "//tensorflow/compiler/xla/service/gpu:target_constants", + "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend", + "//tensorflow/compiler/xla/service/mlir_gpu:kernel_lowering", + "//tensorflow/core:cuda_libdevice_path", + "//tensorflow/core:lib", + ] + if_cuda(["//tensorflow/stream_executor/gpu:asm_compiler"]), +) + +tf_cc_binary( + name = "tf_to_cubin", + srcs = ["tf_to_cubin.cc"], + visibility = ["//tensorflow/core/kernels/cubin_headers:__pkg__"], + deps = [ + ":cubin_creator", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + ], +) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc new file mode 100644 index 00000000000..f47485d0214 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc @@ -0,0 +1,270 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +//===- cubin_creator.cc -----------------------------------------*- C++ -*-===// +// +// This file implements the function to compile a TF kernel function to a cubin. +// +//===----------------------------------------------------------------------===// +#include "tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h" + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/escaping.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Parser.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Target/NVVMIR.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" +#include "tensorflow/compiler/mlir/xla/transforms/passes.h" +#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" +#include "tensorflow/compiler/xla/service/gpu/target_constants.h" +#include "tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h" +#include "tensorflow/core/platform/cuda_libdevice_path.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/path.h" +#if GOOGLE_CUDA +#include "tensorflow/stream_executor/gpu/asm_compiler.h" +#endif + +namespace { +using tensorflow::Status; +using xla::InternalError; +using xla::StatusOr; + +StatusOr GetLibdeviceDir( + const xla::HloModuleConfig& hlo_module_config) { + for (const std::string& cuda_root : tensorflow::CandidateCudaRoots( + hlo_module_config.debug_options().xla_gpu_cuda_data_dir())) { + std::string libdevice_dir = + tensorflow::io::JoinPath(cuda_root, "nvvm", "libdevice"); + VLOG(2) << "Looking for libdevice at " << libdevice_dir; + if (tensorflow::Env::Default()->IsDirectory(libdevice_dir).ok()) { + VLOG(2) << "Found libdevice dir " << libdevice_dir; + return libdevice_dir; + } + } + return InternalError( + "Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice"); +} + +struct MaterializeBroadcastsPass + : public mlir::PassWrapper { + void runOnFunction() override { + mlir::ConversionTarget conversionTarget(getContext()); + mlir::OwningRewritePatternList conversionPatterns; + + // Consider the xla_hlo dialect legal for tests. + conversionTarget.addLegalDialect(); + // The conversion uses helpers from the Standard dialect. 
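+ // (It is therefore marked legal as well, so the partial conversion below
+ // should only rewrite the broadcast-related ops targeted by the
+ // materialize-broadcasts patterns.)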
+ conversionTarget.addLegalDialect(); + + mlir::xla_hlo::SetupMaterializeBroadcastsLegality(&getContext(), + &conversionTarget); + mlir::xla_hlo::PopulateMaterializeBroadcastsPatterns(&getContext(), + &conversionPatterns); + + if (failed(applyPartialConversion(getFunction(), conversionTarget, + conversionPatterns))) { + return signalPassFailure(); + } + } +}; + +struct UnfuseBatchNormPass + : public mlir::PassWrapper { + void runOnFunction() override { + mlir::OwningRewritePatternList patterns; + mlir::xla_hlo::PopulateUnfuseBatchNormPatterns(&getContext(), &patterns); + mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); + } +}; + +Status LowerTfOpToLhloWithDynamicShapes(mlir::ModuleOp module) { + mlir::PassManager pm(module.getContext()); + auto enable_if_vlog_is_on = [](mlir::Pass* pass, mlir::Operation* op) { + return VLOG_IS_ON(1); + }; + pm.enableIRPrinting(/*shouldPrintBeforePass=*/{}, + /*shouldPrintAfterPass=*/enable_if_vlog_is_on, + /*printModuleScope=*/false, + /*printAfterOnlyOnChange=*/false, llvm::dbgs()); + pm.addNestedPass(mlir::xla_hlo::createLegalizeTFPass(false)); + pm.addNestedPass( + absl::make_unique()); + pm.addNestedPass(absl::make_unique()); + pm.addPass(mlir::xla_hlo::createLegalizeToLhloPass()); + pm.addNestedPass(mlir::xla_lhlo::createLhloCopyRemovalPass()); + + if (failed(pm.run(module))) { + return InternalError("Lowering TF to LHLO failed."); + } + return Status::OK(); +} + +struct PropagateStaticKnowledge + : public mlir::PassWrapper> { + explicit PropagateStaticKnowledge(mlir::FunctionType type, + llvm::ArrayRef same_shape_) + : func_type(type), same_shape(same_shape_) {} + + void runOnOperation() override { + // We know due to tensorflow ABI that the offset is always 0 and that the + // innermost stride is always 1. To make this visible to the compiler, + // we insert constants into the code and replace usages accordingly. + // We do not change the signature so that we keep a somewhat stable ABI + // that is easy to understand by tools. + mlir::LLVM::LLVMFuncOp func = getOperation(); + mlir::OpBuilder b(func.getBody()); + auto index_type = func.getArgument(3).getType(); + mlir::Value one = b.create( + func.getLoc(), index_type, b.getIntegerAttr(b.getIndexType(), 1)); + mlir::Value zero = b.create( + func.getLoc(), index_type, b.getIntegerAttr(b.getIndexType(), 0)); + uint32_t arg_pos = 0; + std::vector positions; + for (mlir::Type arg_type : func_type.getInputs()) { + positions.push_back(arg_pos); + func.getArgument(arg_pos + 2).replaceAllUsesWith(zero); + arg_pos += 3 + arg_type.cast().getRank() * 2; + func.getArgument(arg_pos - 1).replaceAllUsesWith(one); + } + + // If we have knowledge that some arguments have the same shape, we + // can use that here. Simply replace usages of the shape parameters within + // the function body with a single shape parameter. + if (!same_shape.empty()) { + auto first = same_shape.front(); + auto first_offset = positions.at(first); + mlir::ShapedType first_type = + func_type.getInput(first).cast(); + uint32_t rank = first_type.getRank(); + for (auto same : same_shape.drop_front(1)) { + uint32_t same_offset = positions.at(same); + auto same_type = func_type.getInput(same).cast(); + if (same_type.getRank() != rank) { + func.emitOpError() << "same shape constraints on arguments with " + "non-matching shapes: #" + << first << " and #" << same; + signalPassFailure(); + } + + for (uint32_t i = 0; i < 2 * rank; ++i) { + // Replace uses for second arg data with first arg.
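+ // (Assuming the standard memref-to-LLVM descriptor expansion, each memref
+ // argument occupies 3 + 2 * rank positions: allocated pointer, aligned
+ // pointer, offset, then rank sizes followed by rank strides, so an
+ // argument's shape data starts at its recorded position + 3.)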
+ auto same_arg = func.getArgument(same_offset + 3 + i); + auto first_arg = func.getArgument(first_offset + 3 + i); + same_arg.replaceAllUsesWith(first_arg); + } + } + } + } + + mlir::FunctionType func_type; + llvm::ArrayRef same_shape; +}; + +Status PropagateStaticShapeKnowledgeToKernel( + mlir::ModuleOp module, llvm::ArrayRef same_shape) { + // Grab the original signature from the single function. + auto func = *module.getBody()->op_begin(); + + mlir::PassManager pm(module.getContext()); + auto enable_if_vlog_is_on = [](mlir::Pass*, mlir::Operation*) { + return VLOG_IS_ON(1); + }; + pm.enableIRPrinting(/*shouldPrintBeforePass=*/{}, + /*shouldPrintAfterPass=*/enable_if_vlog_is_on, + /*printModuleScope=*/false, + /*printAfterOnlyOnChange=*/false, llvm::dbgs()); + auto& kernel_pm = pm.nest<::mlir::gpu::GPUModuleOp>(); + kernel_pm.addNestedPass( + absl::make_unique(func.getType(), same_shape)); + + if (failed(pm.run(module))) { + return InternalError("Static knowledge propagation failed."); + } + return Status::OK(); +} +} // namespace + +StatusOr> tensorflow::kernel_gen::GenerateCubinForTfCode( + llvm::StringRef tf_code, std::pair compute_capability, + llvm::ArrayRef tile_sizes, llvm::ArrayRef same_shape, + llvm::ArrayRef unroll_factors) { + mlir::MLIRContext context; + context.allowUnregisteredDialects(); // TODO(b/152572127) + mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context); + + TF_RETURN_IF_ERROR(LowerTfOpToLhloWithDynamicShapes(module.get())); + TF_RETURN_IF_ERROR( + xla::mlir_gpu::LowerLHLOToGPU(module.get(), tile_sizes, unroll_factors, + /*collapseParallelLoops=*/false)); + TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get())); + // TODO(b/156985522): Figure out why we get a segfault when generating Tanh + // with 'same_shape' containing {0, 1}. We would also get the crash if we + // unconditionally call PropagateStaticShapeKnowledgeToKernel while + // 'same_shape' is empty. + if (!same_shape.empty()) { + TF_RETURN_IF_ERROR( + PropagateStaticShapeKnowledgeToKernel(module.get(), same_shape)); + } + + mlir::OwningModuleRef kernel_module = + xla::mlir_gpu::ExtractKernelModule(*module).ValueOrDie(); + auto llvmModule = mlir::translateModuleToNVVMIR(*kernel_module); + if (!llvmModule) { + return InternalError("Could not translate MLIR module to NVVM"); + } + + llvmModule->setModuleIdentifier("acme"); + llvmModule->setDataLayout(xla::gpu::nvptx::kDataLayout); + + xla::HloModuleConfig config; + config.set_debug_options(xla::GetDebugOptionsFromFlags()); + + TF_ASSIGN_OR_RETURN(std::string libdevice_dir, GetLibdeviceDir(config)); + TF_ASSIGN_OR_RETURN(std::string ptx, xla::gpu::nvptx::CompileToPtx( + llvmModule.get(), compute_capability, + config, libdevice_dir)); + VLOG(1) << ptx; + +#if GOOGLE_CUDA + return tensorflow::se::CompileGpuAsm( + std::get<0>(compute_capability), std::get<1>(compute_capability), + ptx.c_str(), xla::gpu::PtxOptsFromConfig(config)); +#else + return InternalError( + "GOOGLE_CUDA not defined. Did you specify --config=cuda ?"); +#endif +} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h new file mode 100644 index 00000000000..47626ba9d0d --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h @@ -0,0 +1,42 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +//===- cubin_creator.h ------------------------------------------*- C++ -*-===// +// +// This file declares the function to compile a TF kernel function to a cubin. +// +//===----------------------------------------------------------------------===// +#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_ +#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_ + +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace tensorflow { +namespace kernel_gen { +xla::StatusOr> GenerateCubinForTfCode( + llvm::StringRef tf_code, + std::pair compute_capability = {7, 5}, + llvm::ArrayRef tile_sizes = {16, 64}, + llvm::ArrayRef same_shape = {}, + llvm::ArrayRef unroll_factors = {}); +} // namespace kernel_gen +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_ diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc new file mode 100644 index 00000000000..8edc567e777 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc @@ -0,0 +1,118 @@ +// Copyright 2020 The TensorFlow Runtime Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//===- tf_to_cubin.cc -------------------------------------------*- C++ -*-===// +// +// This file implements the entry point to compile a tf op to a cubin file. 
+// +//===----------------------------------------------------------------------===// +#include +#include +#include + +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace { +bool ParseStringList(std::string string_list, std::vector* result) { + result->clear(); + uint32_t item; + auto items = absl::StrSplit(string_list, ','); + for (const auto& item_str : items) { + if (!absl::SimpleAtoi(item_str, &item)) { + LOG(ERROR) << "Expected token " << item_str << " to be an integer"; + return false; + } + result->push_back(item); + } + return true; +} +} // namespace + +int main(int argc, char** argv) { + std::string output_file = "foo.bin"; + int32_t architecture = 50; + std::vector tile_sizes; + std::vector unroll_factors; + std::vector same_shape; + + auto parse_tile_sizes = [&tile_sizes](std::string tile_sizes_str) { + if (!ParseStringList(tile_sizes_str, &tile_sizes)) { + return false; + } + // Initialize with the default. + if (tile_sizes.empty()) { + tile_sizes.push_back(16); + tile_sizes.push_back(64); + } + return true; + }; + + auto parse_unroll_factors = + [&unroll_factors](std::string unroll_factors_str) { + return ParseStringList(unroll_factors_str, &unroll_factors); + }; + + auto parse_same_shape = [&same_shape](std::string same_shape_str) { + return ParseStringList(same_shape_str, &same_shape); + }; + + std::vector flag_list = { + tensorflow::Flag("output", &output_file, "output file"), + tensorflow::Flag("arch", &architecture, + "target architecture (e.g. 
50 for sm_50)"), + tensorflow::Flag("tile_sizes", parse_tile_sizes, "16,64", + "tile sizes to use"), + tensorflow::Flag("unroll_factors", parse_unroll_factors, "", + "factors to unroll by, separated by commas"), + tensorflow::Flag("same_shape", parse_same_shape, "", + "arguments with same shape, separated by commas"), + }; + bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); + tensorflow::port::InitMain("usage", &argc, &argv); + if (!parse_ok) { + return 1; + } + + std::pair compute_capability(architecture / 10, + architecture % 10); + + auto cubin = tensorflow::kernel_gen::GenerateCubinForTfCode( + argv[1], compute_capability, tile_sizes, same_shape, unroll_factors); + + if (!cubin.ok()) { + LOG(ERROR) << cubin.status(); + return 1; + } + + std::vector cubin_data = cubin.ConsumeValueOrDie(); + + auto status = tensorflow::WriteStringToFile( + tensorflow::Env::Default(), output_file, + absl::string_view{reinterpret_cast(cubin_data.data()), + cubin_data.size()}); + + if (!status.ok()) { + LOG(ERROR) << status; + return 1; + } + + return 0; +} diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index e4309d5eef0..736651b5022 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -23,7 +23,6 @@ package_group( "//tensorflow/compiler/xla/...", "//third_party/iree/...", "//third_party/mlir_edge/...", - "//third_party/tf_runtime/tools/tf_kernel_gen/...", ], ) @@ -39,7 +38,8 @@ filegroup( "ir/lhlo_ops.td", "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", - "@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td", + "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", + "@llvm-project//mlir:include/mlir/Interfaces/ViewLikeInterface.td", ], ) @@ -51,9 +51,7 @@ gentbl( ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/chlo_ops.td", - td_srcs = [ - ":hlo_ops_td_files", - ], + td_srcs = [":hlo_ops_td_files"], ) gentbl( @@ -113,16 +111,11 @@ gentbl( gentbl( name = "xla_canonicalize_inc_gen", tbl_outs = [ - ( - "-gen-rewriters", - "transforms/generated_canonicalize.inc", - ), + ("-gen-rewriters", "transforms/generated_canonicalize.inc"), ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/canonicalize.td", - td_srcs = [ - ":hlo_ops_td_files", - ], + td_srcs = [":hlo_ops_td_files"], ) cc_library( @@ -133,6 +126,7 @@ cc_library( "transforms/legalize_tf_control_flow.cc", ], deps = [ + ":chlo_legalize_to_hlo", ":convert_op_folder", ":hlo", "//tensorflow/compiler/mlir/tensorflow", @@ -157,9 +151,7 @@ cc_library( cc_library( name = "xla_legalize_tf_with_tf2xla", - srcs = [ - "transforms/legalize_tf_with_tf2xla.cc", - ], + srcs = ["transforms/legalize_tf_with_tf2xla.cc"], deps = [ ":hlo", ":mlir_hlo_builder", @@ -193,6 +185,22 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "xla_sink_constants_to_control_flow", + srcs = ["transforms/sink_constants_to_control_flow.cc"], + deps = [ + ":hlo", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], + alwayslink = 1, +) + cc_library( name = "map_xla_to_scalar_op", hdrs = ["transforms/map_xla_to_scalar_op.h"], @@ -241,14 +249,29 @@ cc_library( "@llvm-project//llvm:support", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgOps", - "@llvm-project//mlir:LoopOps", 
"@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Transforms", ], alwayslink = 1, ) +cc_library( + name = "lhlo_legalize_to_llvm", + srcs = ["transforms/lhlo_legalize_to_llvm.cc"], + deps = [ + ":lhlo", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMTransforms", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], + alwayslink = 1, +) + cc_library( name = "xla_legalize_to_linalg", srcs = ["transforms/xla_legalize_to_linalg.cc"], @@ -267,6 +290,21 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "xla_transform_unranked_hlo", + srcs = ["transforms/xla_transform_unranked_hlo.cc"], + deps = [ + ":hlo", + "@com_google_absl//absl/memory", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Shape", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Transforms", + ], + alwayslink = 1, +) + cc_library( name = "lhlo_legalize_to_gpu", srcs = ["transforms/lhlo_legalize_to_gpu.cc"], @@ -279,8 +317,8 @@ cc_library( "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgOps", - "@llvm-project//mlir:LoopOps", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Transforms", ], @@ -331,42 +369,6 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "buffer_assignment", - srcs = ["transforms/buffer_assignment.cc"], - hdrs = ["transforms/buffer_assignment.h"], - deps = [ - "@com_google_absl//absl/memory", - "@llvm-project//mlir:Analysis", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:StandardOps", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:Transforms", - ], - alwayslink = 1, -) - -cc_library( - name = "buffer_assignment_test", - srcs = ["transforms/buffer_assignment_test.cc"], - hdrs = [ - "transforms/buffer_assignment.h", - "transforms/passes.h", - ], - deps = [ - "@com_google_absl//absl/memory", - "@llvm-project//llvm:support", - "@llvm-project//mlir:Analysis", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:StandardOps", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:Transforms", - ], - alwayslink = 1, -) - gentbl( name = "xla_legalize_to_standard_inc_gen", tbl_outs = [ @@ -382,9 +384,7 @@ gentbl( cc_library( name = "xla_legalize_control_flow", - srcs = [ - "transforms/legalize_control_flow.cc", - ], + srcs = ["transforms/legalize_control_flow.cc"], deps = [ ":hlo", "@llvm-project//llvm:support", @@ -471,9 +471,7 @@ cc_library( cc_library( name = "xla_materialize_broadcasts", - srcs = [ - "transforms/materialize_broadcasts.cc", - ], + srcs = ["transforms/materialize_broadcasts.cc"], deps = [ ":hlo", "@llvm-project//mlir:IR", @@ -484,9 +482,7 @@ cc_library( cc_library( name = "xla_unfuse_batch_norm", - srcs = [ - "transforms/unfuse_batch_norm.cc", - ], + srcs = ["transforms/unfuse_batch_norm.cc"], deps = [ ":hlo", "@llvm-project//llvm:support", @@ -498,9 +494,7 @@ cc_library( cc_library( name = "chlo_legalize_to_hlo", - srcs = [ - "transforms/chlo_legalize_to_hlo.cc", - ], + srcs = ["transforms/chlo_legalize_to_hlo.cc"], deps = [ ":hlo", "@llvm-project//mlir:IR", @@ -513,6 +507,7 @@ cc_library( name = "xla_test_passes", srcs = [ "transforms/chlo_legalize_to_hlo_pass.cc", + "transforms/lhlo_legalize_to_llvm_pass.cc", "transforms/materialize_broadcasts_pass.cc", "transforms/test_infer_shaped_type_pass.cc", 
"transforms/unfuse_batch_norm_pass.cc", @@ -520,10 +515,14 @@ cc_library( deps = [ ":chlo_legalize_to_hlo", # build-cleaner: keep ":hlo", + ":lhlo", + ":lhlo_legalize_to_llvm", # build-cleaner: keep ":xla_materialize_broadcasts", # build-cleaner: keep ":xla_unfuse_batch_norm", # build-cleaner: keep "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMTransforms", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Shape", "@llvm-project//mlir:StandardOps", @@ -575,12 +574,8 @@ cc_library( cc_library( name = "mlir_hlo_builder", - srcs = [ - "ir/mlir_hlo_builder.cc", - ], - hdrs = [ - "ir/mlir_hlo_builder.h", - ], + srcs = ["ir/mlir_hlo_builder.cc"], + hdrs = ["ir/mlir_hlo_builder.h"], deps = [ ":attribute_importer", ":hlo", @@ -626,6 +621,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:ViewLikeInterface", ], alwayslink = 1, ) @@ -823,7 +819,7 @@ genrule( name = "operator_writer_inc", srcs = [ "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", - "@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td", + "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", "@llvm-project//mlir:include/mlir/IR/OpBase.td", ":ir/hlo_ops.td", ":ir/hlo_ops_base.td", @@ -853,8 +849,6 @@ cc_library( "//tensorflow/compiler/mlir:__subpackages__", ], deps = [ - ":buffer_assignment", - ":buffer_assignment_test", ":chlo_legalize_to_hlo", ":hlo", ":hlo_legalize_to_lhlo", @@ -872,8 +866,9 @@ cc_library( ":xla_legalize_to_linalg", ":xla_legalize_to_standard", ":xla_lower", - ":xla_materialize_broadcasts", + ":xla_sink_constants_to_control_flow", ":xla_test_passes", + ":xla_transform_unranked_hlo", ], ) diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index 5dc610a5670..22a0b038833 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -420,15 +420,37 @@ StatusOr HloFunctionImporter::ImportInstruction( } case HloOpcode::kConditional: { llvm::SmallVector rets; - TF_RETURN_IF_ERROR(GetMlirTypes( - {instruction->true_computation()->root_instruction()}, &rets)); + mlir::Type pred_or_index_type = + operands[0].getType().cast().getElementType(); + // It is a predicated conditional if first argument is a boolean and + // should be mapped to If op. + if (pred_or_index_type.isInteger(1)) { + TF_RETURN_IF_ERROR(GetMlirTypes( + {instruction->true_computation()->root_instruction()}, &rets)); - auto op = func_builder->create( - loc, rets, operands, attributes); - TF_RETURN_IF_ERROR(ImportComputation(instruction->true_computation(), - &op.true_branch())); - TF_RETURN_IF_ERROR(ImportComputation(instruction->false_computation(), - &op.false_branch())); + auto op = func_builder->create(loc, rets, operands, + attributes); + TF_RETURN_IF_ERROR(ImportComputation(instruction->true_computation(), + &op.true_branch())); + TF_RETURN_IF_ERROR(ImportComputation(instruction->false_computation(), + &op.false_branch())); + return op.getOperation(); + } + + // Otherwise, it is a indexed conditional and should be mapped to Case op. 
+ TF_RETURN_IF_ERROR(GetMlirTypes( + {instruction->branch_computation(0)->root_instruction()}, &rets)); + + int num_branches = instruction->branch_count(); + auto op = func_builder->create( + loc, rets, operands, attributes, num_branches); + for (auto index_and_computation : + llvm::enumerate(instruction->branch_computations())) { + auto index = index_and_computation.index(); + HloComputation* computation = index_and_computation.value(); + TF_RETURN_IF_ERROR( + ImportComputation(computation, &op.branches()[index])); + } return op.getOperation(); } case HloOpcode::kConcatenate: { diff --git a/tensorflow/compiler/mlir/xla/hlo_utils.cc b/tensorflow/compiler/mlir/xla/hlo_utils.cc index c685cc296fd..dc801f64ede 100644 --- a/tensorflow/compiler/mlir/xla/hlo_utils.cc +++ b/tensorflow/compiler/mlir/xla/hlo_utils.cc @@ -139,6 +139,10 @@ StatusOr CreateDenseElementsAttrFromLiteral( return CreateDenseAttrFromLiteral(type, literal); case PrimitiveType::U64: return CreateDenseAttrFromLiteral(type, literal); + case PrimitiveType::C64: + return CreateDenseAttrFromLiteral(type, literal); + case PrimitiveType::C128: + return CreateDenseAttrFromLiteral(type, literal); default: return tensorflow::errors::Internal( absl::StrCat("Unsupported type: ", PrimitiveType_Name(element_type))); diff --git a/tensorflow/compiler/mlir/xla/ir/chlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/chlo_ops.cc index bc6842a617e..26db4549a2a 100644 --- a/tensorflow/compiler/mlir/xla/ir/chlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/chlo_ops.cc @@ -97,16 +97,12 @@ static Type GetBroadcastType(Type x, Type y, Type element_type, LogicalResult InferBroadcastBinaryOpReturnTypeComponents( MLIRContext* context, Optional location, ValueRange operands, - ArrayRef attributes, Type element_type, + DictionaryAttr attributes, Type element_type, SmallVectorImpl& inferedReturnShapes) { // Find broadcast_dimensions. - DenseIntElementsAttr broadcast_dimensions; - for (auto attr : attributes) { - if (attr.first == "broadcast_dimensions") { - broadcast_dimensions = attr.second.dyn_cast(); - break; - } - } + DenseIntElementsAttr broadcast_dimensions = + attributes.get("broadcast_dimensions") + .dyn_cast_or_null(); ShapedType lhs_type = operands[0].getType().dyn_cast(); ShapedType rhs_type = operands[1].getType().dyn_cast(); @@ -168,7 +164,7 @@ LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes( LogicalResult BroadcastComplexOp::inferReturnTypeComponents( MLIRContext* context, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferedReturnShapes) { ShapedType lhs_type = operands[0].getType().dyn_cast(); if (!lhs_type) { @@ -189,9 +185,19 @@ LogicalResult BroadcastComplexOp::reifyReturnTypeShapes( // BroadcastCompareOp (has custom type inference due to different result type). 
//===----------------------------------------------------------------------===// +void BroadcastCompareOp::build(OpBuilder& builder, OperationState& result, + Value lhs, Value rhs, + DenseIntElementsAttr broadcast_dimensions, + StringAttr comparison_direction) { + auto new_type = GetBroadcastType(lhs.getType(), rhs.getType(), + builder.getI1Type(), broadcast_dimensions); + build(builder, result, new_type, lhs, rhs, broadcast_dimensions, + comparison_direction); +} + LogicalResult BroadcastCompareOp::inferReturnTypeComponents( MLIRContext* context, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferedReturnShapes) { Type element_type = IntegerType::get(1, context); return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands, @@ -211,7 +217,7 @@ LogicalResult BroadcastCompareOp::reifyReturnTypeShapes( #define BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(Op) \ LogicalResult Op::inferReturnTypeComponents( \ MLIRContext* context, Optional location, ValueRange operands, \ - ArrayRef attributes, RegionRange regions, \ + DictionaryAttr attributes, RegionRange regions, \ SmallVectorImpl& inferedReturnShapes) { \ return InferBroadcastBinaryOpReturnTypeComponents( \ context, location, operands, attributes, /*element_type=*/nullptr, \ diff --git a/tensorflow/compiler/mlir/xla/ir/chlo_ops.h b/tensorflow/compiler/mlir/xla/ir/chlo_ops.h index 474d4b7d95a..a5337907579 100644 --- a/tensorflow/compiler/mlir/xla/ir/chlo_ops.h +++ b/tensorflow/compiler/mlir/xla/ir/chlo_ops.h @@ -25,7 +25,7 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project -#include "mlir/Interfaces/SideEffects.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project namespace mlir { namespace xla_chlo { diff --git a/tensorflow/compiler/mlir/xla/ir/chlo_ops.td b/tensorflow/compiler/mlir/xla/ir/chlo_ops.td index a244985c9b5..febc99f6b72 100644 --- a/tensorflow/compiler/mlir/xla/ir/chlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/chlo_ops.td @@ -31,7 +31,7 @@ limitations under the License. 
include "mlir/IR/OpBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" -include "mlir/Interfaces/SideEffects.td" +include "mlir/Interfaces/SideEffectInterfaces.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td" def HLOClient_Dialect : Dialect { @@ -360,6 +360,11 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp< HLO_ComparisonDirectionAttr:$comparison_direction ); let results = (outs HLO_PredTensor); + + let builders = [OpBuilder< + "OpBuilder &builder, OperationState &result, Value lhs, Value rhs, " + "DenseIntElementsAttr broadcast_dimensions, StringAttr comparison_direction" + >]; } #endif // CHLO_OPS diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc index cb7372a762c..c66b8f12332 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc @@ -262,6 +262,34 @@ static LogicalResult Verify(IotaOp op) { return success(); } +//===----------------------------------------------------------------------===// +// DynamicIotaOp +//===----------------------------------------------------------------------===// + +namespace { + +struct DynamicIotaIsStatic : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DynamicIotaOp iota, + PatternRewriter& rewriter) const override { + auto result_ty = iota.getType().cast(); + if (!result_ty.hasStaticShape()) { + return failure(); + } + + rewriter.replaceOpWithNewOp(iota, result_ty, iota.iota_dimension()); + return success(); + } +}; + +} // namespace + +void DynamicIotaOp::getCanonicalizationPatterns( + OwningRewritePatternList& results, MLIRContext* context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // AbsOp //===----------------------------------------------------------------------===// @@ -567,48 +595,6 @@ OpFoldResult BroadcastInDimOp::fold(ArrayRef) { return getOperand(); } -//===----------------------------------------------------------------------===// -// ScalarsToDimensionTensorOp -//===----------------------------------------------------------------------===// - -namespace { - -// Canonicalizes the pattern of the form -// -// %2 = "xla_hlo.scalars_to_dimension_tensor"(%0, %1) -// : (i32, i32) -> tensor<2xi32> -// %3 = extract_element %2[%c0] : tensor<2xi32> -// -// to just %0. 
-struct ExtractElementFromScalarsToDimensionTensor - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ExtractElementOp extract, - PatternRewriter& rewriter) const override { - if (extract.indices().size() != 1) return failure(); - - if (auto scalars_to_tensor = dyn_cast_or_null( - extract.aggregate().getDefiningOp())) { - APInt index; - if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index))) { - return failure(); - } - rewriter.replaceOp(extract, - scalars_to_tensor.getOperand(index.getZExtValue())); - return success(); - } - return failure(); - } -}; - -} // namespace - -void ScalarsToDimensionTensorOp::getCanonicalizationPatterns( - OwningRewritePatternList& results, MLIRContext* context) { - results.insert(context); -} - //===----------------------------------------------------------------------===// // DynamicBroadcastInDimOp //===----------------------------------------------------------------------===// @@ -1170,9 +1156,22 @@ OpFoldResult CopyOp::fold(ArrayRef operands) { return getOperand(); } //===----------------------------------------------------------------------===// OpFoldResult ReverseOp::fold(ArrayRef operands) { + auto input = operand(); + // No dimensions to reverse. - if (dimensions().getNumElements() == 0) return operand(); - return nullptr; + if (dimensions().getNumElements() == 0) return input; + + llvm::SmallVector new_dims; + new_dims.reserve(dimensions().getNumElements()); + + auto shaped_type = input.getType().cast(); + for (auto dim : dimensions().getValues()) { + if (shaped_type.getDimSize(dim.getLimitedValue()) != 1) { + return nullptr; + } + } + + return input; } //===----------------------------------------------------------------------===// @@ -1240,7 +1239,7 @@ static LogicalResult Verify(SelectOp op) { // the return type based on operand type. LogicalResult SelectOp::inferReturnTypes( MLIRContext*, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { auto x_type = operands[1].getType(); auto y_type = operands[2].getType(); @@ -1345,19 +1344,23 @@ static LogicalResult Verify(PadOp op) { //===----------------------------------------------------------------------===// static LogicalResult Verify(ReshapeOp op) { - auto operand_ty = op.operand().getType().cast(); + // If the operand type is dynamically shaped there is nothing to verify. + auto operand_ty = op.operand().getType().cast(); if (!operand_ty || !operand_ty.hasStaticShape()) return success(); - int64_t num_input_elements = operand_ty.getNumElements(); - auto out_ty = op.getType().cast(); - if (out_ty && out_ty.hasStaticShape()) { - int64_t num_output_elements = out_ty.getNumElements(); - if (num_input_elements != num_output_elements) - return op.emitOpError() - << "number of output elements (" << num_output_elements - << ") doesn't match expected number of elements (" - << num_input_elements << ")"; - } + // If the operand type is statically shaped (not required) the number of + // elements must match that of the result type. 
+ auto result_ty = op.getType().cast(); + assert(result_ty && result_ty.hasStaticShape() && + "result type must be statically shaped"); + int64_t num_result_elements = result_ty.getNumElements(); + int64_t num_operand_elements = operand_ty.getNumElements(); + if (num_result_elements != num_operand_elements) + return op.emitOpError() + << "number of output elements (" << num_result_elements + << ") doesn't match expected number of elements (" + << num_operand_elements << ")"; + return success(); } @@ -1379,94 +1382,71 @@ OpFoldResult ReshapeOp::fold(ArrayRef operands) { return {}; } +//===----------------------------------------------------------------------===// +// Case Op +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(CaseOp op) { + auto num_branches = op.branches().size(); + if (op.branch_operands().size() != num_branches) + return op.emitOpError() << "expects number of branches " << num_branches + << " to be same as number of branch operands " + << op.branch_operands().size(); + + MutableArrayRef branches = op.branches(); + OperandRange branch_operands = op.branch_operands(); + for (unsigned i = 0; i < num_branches; ++i) { + mlir::Region& branch_region = branches[i]; + if (branch_region.empty()) + return op.emitOpError() << "cannot have empty regions"; + mlir::Block& entry_block = branch_region.front(); + if (entry_block.getNumArguments() != 1) + return op.emitOpError() + << "expects branch regions to have single argument, but found " + << entry_block.getNumArguments() << " for branch " << i; + auto operand = branch_operands[i]; + if (entry_block.getArgument(0).getType() != operand.getType()) + return op.emitOpError() + << "expects operand " << i + 1 << " to be of type " + << entry_block.getArgument(0).getType() << ", but found " + << operand.getType(); + WalkResult walker = branch_region.walk([&](ReturnOp return_op) { + if (return_op.getOperands().getTypes() != op.getResultTypes()) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + if (walker.wasInterrupted()) + return op.emitOpError() + << "branch " << i + << " returned values do not match op result types"; + } + return success(); +} + //===----------------------------------------------------------------------===// // BinaryOps //===----------------------------------------------------------------------===// namespace { -// Gets the resulting type from a broadcast between two types. -static Type GetBroadcastType(Builder* builder, Type x, Type y, - Type element_type, - DenseIntElementsAttr broadcast_dimensions) { + +// Updates the element type of a (presumed) tensor type 'x', returning either +// a permuted UnrankedTensorType or RankedTensorType. +static Type UpdateResultElementType(Builder* builder, Type x, + Type element_type) { auto x_ranked = x.dyn_cast(); - auto y_ranked = y.dyn_cast(); - if (!x_ranked || !y_ranked) { + if (!x_ranked) { return UnrankedTensorType::get(element_type); } auto shape_x = x_ranked.getShape(); - auto shape_y = y_ranked.getShape(); - - if (shape_x.size() == shape_y.size()) { - llvm::SmallVector out_shape(shape_x.size()); - for (int i = 0; i < shape_x.size(); i++) { - auto x_val = shape_x[i]; - auto y_val = shape_y[i]; - if (x_val == -1 || y_val == -1) { - out_shape[i] = -1; - } else { - out_shape[i] = std::max(x_val, y_val); - } - } - return RankedTensorType::get(out_shape, element_type); - } - - // Return unranked tensor for invalid broadcast dimensions. 
- if (!broadcast_dimensions) return UnrankedTensorType::get(element_type); - - auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y; - auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y; - - llvm::SmallVector out_shape(shape_large.begin(), - shape_large.end()); - - // Update according to the broadcast dimensions. - for (auto index_pair : llvm::enumerate(broadcast_dimensions.getIntValues())) { - auto old_value = out_shape[index_pair.value().getSExtValue()]; - auto new_value = shape_small[index_pair.index()]; - if (old_value != -1 && (new_value == -1 || new_value > old_value)) { - out_shape[index_pair.value().getSExtValue()] = new_value; - } - } - - return RankedTensorType::get(out_shape, element_type); + return RankedTensorType::get(shape_x, element_type); } } // namespace -#define BINARY_BUILDER(Op) \ - void Op::build(OpBuilder& builder, OperationState& result, Value left, \ - Value right, DenseIntElementsAttr broadcast_dimensions) { \ - auto type = GetBroadcastType(&builder, left.getType().cast(), \ - right.getType().cast(), \ - getElementTypeOrSelf(right.getType()), \ - broadcast_dimensions); \ - return Op::build(builder, result, type, left, right, \ - broadcast_dimensions); \ - } - -BINARY_BUILDER(AddOp); -BINARY_BUILDER(AndOp); -BINARY_BUILDER(Atan2Op); -BINARY_BUILDER(DivOp); -BINARY_BUILDER(MaxOp); -BINARY_BUILDER(MinOp); -BINARY_BUILDER(MulOp); -BINARY_BUILDER(OrOp); -BINARY_BUILDER(PowOp); -BINARY_BUILDER(RemOp); -BINARY_BUILDER(ShiftLeftOp); -BINARY_BUILDER(ShiftRightArithmeticOp); -BINARY_BUILDER(ShiftRightLogicalOp); -BINARY_BUILDER(SubOp); -BINARY_BUILDER(XorOp); - -#undef BINARY_BUILDER - template static Attribute BinaryFolder(Op* op, ArrayRef attrs) { if (!attrs[0] || !attrs[1]) return {}; - if (op->broadcast_dimensions().hasValue()) return {}; DenseElementsAttr lhs = attrs[0].dyn_cast(); DenseElementsAttr rhs = attrs[1].dyn_cast(); @@ -1494,6 +1474,38 @@ static Attribute BinaryFolder(Op* op, ArrayRef attrs) { return DenseElementsAttr::get(type, values); } +template +struct divide : std::divides {}; + +template <> +struct divide { + APInt operator()(const APInt& a, const APInt& b) const { return a.sdiv(b); } +}; + +template +struct max { + T operator()(const T& a, const T& b) const { return std::max(a, b); } +}; + +template <> +struct max { + APInt operator()(const APInt& a, const APInt& b) const { + return llvm::APIntOps::smax(a, b); + } +}; + +template +struct min { + T operator()(const T& a, const T& b) const { return std::min(a, b); } +}; + +template <> +struct min { + APInt operator()(const APInt& a, const APInt& b) const { + return llvm::APIntOps::smin(a, b); + } +}; + #define BINARY_FOLDER(Op, Func) \ OpFoldResult Op::fold(ArrayRef attrs) { \ if (getElementTypeOrSelf(getType()).isa()) \ @@ -1503,9 +1515,16 @@ static Attribute BinaryFolder(Op* op, ArrayRef attrs) { return {}; \ } +// Addition, subtraction and multiplication use the std:: versions of the ops. +// Due to the other ops behaving differently in signed vs unsigned integers, +// APInts need a special implementation. Currently, it replicates signed int +// op behavior. 
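// Illustrative standalone sketch (not from this patch): the divide/max/min
// specializations above are needed because llvm::APInt stores raw bits with
// no signedness, so the caller must pick signed or unsigned arithmetic
// explicitly. The folders use the signed flavor, e.g.:

#include <cstdint>

#include "llvm/ADT/APInt.h"

static void ApIntSignedFoldingSketch() {
  // The 8-bit pattern 0xFA reads as -6 signed and 250 unsigned.
  llvm::APInt a(/*numBits=*/8, static_cast<uint64_t>(-6), /*isSigned=*/true);
  llvm::APInt b(/*numBits=*/8, 3);

  int64_t signed_div = a.sdiv(b).getSExtValue();     // -2, what divide<APInt> computes
  uint64_t unsigned_div = a.udiv(b).getZExtValue();  // 83 (250 / 3), the unsigned alternative
  int64_t signed_max =
      llvm::APIntOps::smax(a, b).getSExtValue();     // 3, what max<APInt> computes
  (void)signed_div;
  (void)unsigned_div;
  (void)signed_max;
}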
BINARY_FOLDER(AddOp, std::plus); BINARY_FOLDER(SubOp, std::minus); BINARY_FOLDER(MulOp, std::multiplies); +BINARY_FOLDER(DivOp, divide); +BINARY_FOLDER(MaxOp, max); +BINARY_FOLDER(MinOp, min); #undef BINARY_FOLDER @@ -1876,12 +1895,10 @@ void UnaryEinsumOp::getCanonicalizationPatterns( //===----------------------------------------------------------------------===// void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs, - Value rhs, DenseIntElementsAttr broadcast_dimensions, - StringAttr comparison_direction) { - auto new_type = GetBroadcastType(&builder, lhs.getType(), rhs.getType(), - builder.getI1Type(), broadcast_dimensions); - build(builder, result, new_type, lhs, rhs, broadcast_dimensions, - comparison_direction); + Value rhs, StringAttr comparison_direction) { + auto new_type = + UpdateResultElementType(&builder, lhs.getType(), builder.getI1Type()); + build(builder, result, new_type, lhs, rhs, comparison_direction); } #define GET_OP_CLASSES @@ -1969,11 +1986,8 @@ LogicalResult deriveShapeFromFirstOperand( loc, builder->getI64IntegerAttr(element.value()))); } } - *reifiedReturnShapes = - SmallVector{builder->create( - loc, - RankedTensorType::get({operand_type.getRank()}, shape_scalar_type), - shape_values)}; + *reifiedReturnShapes = SmallVector{ + builder->create(loc, shape_values)}; return success(); } diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.h b/tensorflow/compiler/mlir/xla/ir/hlo_ops.h index 25b2f009cc6..9725a0684f6 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.h +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.h @@ -29,7 +29,7 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project -#include "mlir/Interfaces/SideEffects.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project namespace mlir { class OpBuilder; diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td index dabf03d3c9f..97b8e1c1863 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td @@ -23,7 +23,7 @@ limitations under the License. include "mlir/IR/OpBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" -include "mlir/Interfaces/SideEffects.td" +include "mlir/Interfaces/SideEffectInterfaces.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td" include "tensorflow/compiler/mlir/xla/ir/hlo_utils.td" @@ -79,6 +79,25 @@ def HLO_IotaOp : HLO_Op<"iota", [NoSideEffect]>, BASE_HLO_IotaOp { let hasCustomHLOConverter = 1; } +def HLO_DynamicIotaOp: HLO_Op<"dynamic_iota", [NoSideEffect]> { + let summary = "Create linear increasing values from 0 to length -1."; + let description = [{ + Produces an HLO Tensor of the specified shape, with an incremental set of + values along the specified dimension starting at 0. + + Requires: + - The output length of the tensor result. + }]; + + let arguments = (ins HLO_DimensionTensor:$output_shape, I64Attr:$iota_dimension); + let results = (outs HLO_Tensor:$result); + + let hasCanonicalizer = 1; + // Cannot be exported to legacy formats. + let hasCustomHLOConverter = 1; +} + + def HLO_CreateTokenOp : HLO_Op<"create_token", [NoSideEffect]> { string summary = "Create Token operator"; @@ -95,6 +114,7 @@ def HLO_CreateTokenOp : HLO_Op<"create_token", [NoSideEffect]> { // XLA unary elementwise op definitions. 
//===----------------------------------------------------------------------===// // See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions + class HLO_UnaryElementwiseOp traits, Type TensorType>: HLO_Op { @@ -103,8 +123,7 @@ class HLO_UnaryElementwiseOp traits, let extraClassDeclaration = [{ static LogicalResult inferReturnTypeComponents( MLIRContext* context, Optional location, - ValueRange operands, ArrayRef attributes, - RegionRange regions, + ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferedReturnShapes) { return failure(); } @@ -161,6 +180,16 @@ def HLO_Expm1Op: HLO_UnaryElementwiseOp<"exponential_minus_one", def HLO_FloorOp: HLO_UnaryElementwiseOp<"floor", [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_FloorOp; +def HLO_ImagOp: HLO_Op< + "imag", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_ImagOp { + let builders = [OpBuilder< + "OpBuilder &, OperationState &tblgen_state, Value val">]; + + let arguments = (ins HLO_ComplexTensor); + let results = (outs HLO_FpTensor); + let hasFolder = 1; +} + def HLO_IsFiniteOp: HLO_UnaryElementwiseOp<"is_finite", [NoSideEffect, SameOperandsAndResultShape], HLO_Tensor>, BASE_HLO_IsFiniteOp { @@ -188,6 +217,16 @@ def HLO_PopulationCountOp: HLO_UnaryElementwiseOp<"popcnt", [NoSideEffect, SameOperandsAndResultType], HLO_IntTensor>, BASE_HLO_PopulationCountOp; +def HLO_RealOp: HLO_Op< + "real", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_RealOp { + let builders = [OpBuilder< + "OpBuilder &, OperationState &tblgen_state, Value val">]; + + let arguments = (ins HLO_ComplexTensor); + let results = (outs HLO_FpTensor); + let hasFolder = 1; +} + def HLO_RoundOp: HLO_UnaryElementwiseOp<"round_nearest_afz", [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_RoundOp; @@ -209,67 +248,25 @@ def HLO_SqrtOp: HLO_UnaryElementwiseOp<"sqrt", BASE_HLO_SqrtOp; def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh", - [ResultsAreFloatLike, NoSideEffect, SameOperandsAndResultType], + [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, BASE_HLO_TanhOp; -//===----------------------------------------------------------------------===// -// XLA complex unary elementwise op definitions. -//===----------------------------------------------------------------------===// -// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions - -def HLO_ComplexOp: HLO_Op<"complex", - [NoSideEffect, SameOperandsElementType, SameOperandsAndResultShape]>, - BASE_HLO_ComplexOp { - let builders = [OpBuilder< - "OpBuilder &, OperationState &tblgen_state, Value lhs, Value rhs">]; - - let arguments = (ins HLO_FpTensor:$lhs, HLO_FpTensor:$rhs); - let results = (outs HLO_ComplexTensor); - let hasFolder = 1; -} - -def HLO_ImagOp: HLO_Op< - "imag", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_ImagOp { - let builders = [OpBuilder< - "OpBuilder &, OperationState &tblgen_state, Value val">]; - - let arguments = (ins HLO_ComplexTensor); - let results = (outs HLO_FpTensor); - let hasFolder = 1; -} - -def HLO_RealOp: HLO_Op< - "real", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_RealOp { - let builders = [OpBuilder< - "OpBuilder &, OperationState &tblgen_state, Value val">]; - - let arguments = (ins HLO_ComplexTensor); - let results = (outs HLO_FpTensor); - let hasFolder = 1; -} - //===----------------------------------------------------------------------===// // XLA binary elementwise op definitions. 
//===----------------------------------------------------------------------===// - // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations + class HLO_BinaryElementwiseOp traits> : HLO_Op { let arguments = (ins HLO_Tensor:$lhs, - HLO_Tensor:$rhs, - OptionalAttr:$broadcast_dimensions + HLO_Tensor:$rhs ); - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, Value left, Value right, " - "DenseIntElementsAttr broadcast_dimensions" - >]; - let extraClassDeclaration = [{ static LogicalResult inferReturnTypeComponents( MLIRContext* context, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferedReturnShapes) { return failure(); } @@ -286,47 +283,61 @@ class HLO_BinaryElementwiseOp traits> : } def HLO_AddOp : HLO_BinaryElementwiseOp<"add", - [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_AddOp { + [Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_AddOp { let hasFolder = 1; } def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_Atan2Op; + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_Atan2Op; + +def HLO_ComplexOp: HLO_Op<"complex", + [NoSideEffect, SameOperandsAndResultShape]>, + BASE_HLO_ComplexOp { + let builders = [OpBuilder< + "OpBuilder &, OperationState &tblgen_state, Value lhs, Value rhs">]; + + let arguments = (ins HLO_FpTensor:$lhs, HLO_FpTensor:$rhs); + let results = (outs HLO_ComplexTensor); + let hasFolder = 1; +} def HLO_DivOp : HLO_BinaryElementwiseOp<"divide", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_DivOp { + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_DivOp { + let hasFolder = 1; } def HLO_MaxOp : HLO_BinaryElementwiseOp<"maximum", - [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MaxOp { + [Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_MaxOp { + let hasFolder = 1; } def HLO_MinOp : HLO_BinaryElementwiseOp<"minimum", - [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MinOp { + [Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_MinOp { + let hasFolder = 1; } def HLO_MulOp : HLO_BinaryElementwiseOp<"multiply", - [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MulOp { + [Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_MulOp { let hasFolder = 1; } def HLO_PowOp : HLO_BinaryElementwiseOp<"power", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_PowOp; + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_PowOp; def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_RemOp; + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp; def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftLeftOp; + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftLeftOp; def HLO_ShiftRightArithmeticOp : HLO_BinaryElementwiseOp<"shift_right_arithmetic", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftRightArithmeticOp; + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftRightArithmeticOp; def HLO_ShiftRightLogicalOp : HLO_BinaryElementwiseOp<"shift_right_logical", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftRightLogicalOp; + [NoSideEffect, SameOperandsAndResultType]>, 
BASE_HLO_ShiftRightLogicalOp; def HLO_SubOp : HLO_BinaryElementwiseOp<"subtract", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_SubOp { + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_SubOp { let hasFolder = 1; } @@ -336,11 +347,11 @@ def HLO_SubOp : HLO_BinaryElementwiseOp<"subtract", // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations class HLO_BinaryLogicalElementwiseOp : - HLO_BinaryElementwiseOp { + HLO_BinaryElementwiseOp< + mnemonic, [Commutative, NoSideEffect, SameOperandsAndResultType]> { let arguments = (ins HLO_PredOrIntTensor:$lhs, - HLO_PredOrIntTensor:$rhs, - OptionalAttr:$broadcast_dimensions + HLO_PredOrIntTensor:$rhs ); } @@ -472,7 +483,7 @@ def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect]>, // XLA control flow op definitions. //===----------------------------------------------------------------------===// -def HLO_AfterAllOp : HLO_Op<"after_all", []> { +def HLO_AfterAllOp : HLO_Op<"after_all", [NoSideEffect]> { string summary = "AfterAll operator"; @@ -489,8 +500,11 @@ def HLO_AfterAllOp : HLO_Op<"after_all", []> { let results = (outs HLO_Token); } -def HLO_ConditionalOp: HLO_Op<"conditional", []> { - string summary = "Conditional operator"; +// Xla Client API has two separate calls for indexed and predicated conditional, +// although both eventually map to kConditional HLO. IfOp maps to predicated +// conditional use of kConditional HLO. +def HLO_IfOp: HLO_Op<"if", [RecursiveSideEffects]> { + string summary = "If operator"; string description = [{ Returns the result of executing either a true or false function depending on @@ -505,7 +519,8 @@ def HLO_ConditionalOp: HLO_Op<"conditional", []> { HLO_TensorOrTuple:$false_arg ); - let regions = (region AnyRegion:$true_branch, AnyRegion:$false_branch); + let regions = (region AnyRegion:$true_branch, + AnyRegion:$false_branch); let results = (outs HLO_TensorOrTuple); @@ -513,7 +528,27 @@ def HLO_ConditionalOp: HLO_Op<"conditional", []> { let hasCustomHLOConverter = 1; } -def HLO_WhileOp: HLO_Op<"while", [SameOperandsAndResultType]> { +// Xla Client API has two separate calls for indexed and predicated conditional, +// although both eventually map to kConditional HLO. CaseOp maps to indexed +// conditional use of kConditional HLO. 
+def HLO_CaseOp: HLO_Op<"case", [RecursiveSideEffects]>, + BASE_HLO_CaseOp { + + let arguments = (ins + I32Tensor:$index, + Variadic:$branch_operands + ); + + let regions = (region VariadicRegion:$branches); + + let results = (outs Variadic); + + let hasCustomHLOConverter = 1; +} + + +def HLO_WhileOp: HLO_Op<"while", [RecursiveSideEffects, + SameOperandsAndResultType]> { string summary = "While operator"; string description = [{ @@ -534,7 +569,7 @@ def HLO_WhileOp: HLO_Op<"while", [SameOperandsAndResultType]> { } def HLO_AllReduceOp : HLO_Op<"all_reduce", - [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_AllReduceOp { + [SameOperandsAndResultType]>, BASE_HLO_AllReduceOp { let arguments = (ins HLO_Tensor:$operand, @@ -561,7 +596,7 @@ def HLO_AllToAllOp : HLO_Op<"all_to_all", } def HLO_ReduceOp: HLO_Op<"reduce", [ - NoSideEffect, + RecursiveSideEffects, SameVariadicOperandSize, SingleBlockImplicitTerminator<"ReturnOp"> ]>, BASE_HLO_ReduceOp { @@ -619,23 +654,18 @@ def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp { } def HLO_CompareOp: HLO_Op<"compare", - [NoSideEffect, SameOperandsElementType]>, BASE_HLO_CompareOp { + [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]>, + BASE_HLO_CompareOp { let arguments = (ins HLO_Tensor:$lhs, HLO_Tensor:$rhs, - OptionalAttr:$broadcast_dimensions, HLO_ComparisonDirectionAttr:$comparison_direction ); - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, Value left, Value right, " - "DenseIntElementsAttr broadcast_dimensions, " - "StringAttr comparison_direction" - >]; let results = (outs HLO_PredTensor); let builders = [OpBuilder< "OpBuilder &builder, OperationState &result, Value lhs, Value rhs, " - "DenseIntElementsAttr broadcast_dimensions, StringAttr comparison_direction" + "StringAttr comparison_direction" >]; } @@ -781,23 +811,6 @@ def HLO_BroadcastInDimOp : HLO_Op<"broadcast_in_dim", let hasCustomHLOConverter = 1; } -def HLO_ScalarsToDimensionTensorOp : HLO_Op<"scalars_to_dimension_tensor", - [SameOperandsElementType, NoSideEffect]> { - string summary = "Converts a sequence of scalars into a 1d tensor."; - - string description = [{ - This is a useful operation that is currently missing in Standard. Used to - compute shape arguments to dynamic operations. - }]; - - let arguments = (ins Variadic:$scalars); - let results = (outs HLO_DimensionTensor); - - // Cannot be exported to legacy formats. 
- let hasCustomHLOConverter = 1; - let hasCanonicalizer = 1; -} - def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim", [NoSideEffect]> { string summary = "Broadcast a tensor into the given dynamic shape by adding dimensions."; @@ -1047,8 +1060,8 @@ def HLO_GetDimensionSizeOp: HLO_Op<"get_dimension_size", [NoSideEffect]>, } def HLO_MapOp: HLO_Op<"map", - [NoSideEffect, SameOperandsElementType, SameOperandsAndResultShape, - SingleBlockImplicitTerminator<"ReturnOp">]>, + [RecursiveSideEffects, SameOperandsElementType, + SameOperandsAndResultShape, SingleBlockImplicitTerminator<"ReturnOp">]>, BASE_HLO_MapOp { let arguments = (ins Variadic:$operands, @@ -1063,13 +1076,13 @@ def HLO_ReshapeOp: HLO_Op<"reshape", [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ReshapeOp { let arguments = (ins HLO_Tensor:$operand); - let results = (outs HLO_Tensor); + let results = (outs HLO_StaticShapeTensor); let hasFolder = 1; let hasCustomHLOConverter = 1; } -def HLO_DynamicReshapeOp: HLO_Op<"dynamic_reshape", []> { +def HLO_DynamicReshapeOp: HLO_Op<"dynamic_reshape", [NoSideEffect]> { let summary = "Reshape a tensor to a given, possibly dynamic, shape."; let description = [{ Reshapes `operand` to `output_shape`. @@ -1097,7 +1110,8 @@ def ScatterDimensionNumbers : StructAttr<"ScatterDimensionNumbers", HLO_Dialect, let description = "Structure of dimension information for scatter"; } -def HLO_ScatterOp: HLO_Op<"scatter", [NoSideEffect]>, BASE_HLO_ScatterOp { +def HLO_ScatterOp: HLO_Op<"scatter", [RecursiveSideEffects]>, + BASE_HLO_ScatterOp { let arguments = (ins HLO_Tensor:$operand, HLO_Tensor:$scatter_indices, @@ -1126,7 +1140,7 @@ def HLO_SelectOp: HLO_Op<"select", [NoSideEffect, DeclareOpInterfaceMethods, BASE_HLO_SelectAndScatterOp { + [RecursiveSideEffects]>, BASE_HLO_SelectAndScatterOp { let arguments = (ins HLO_Tensor:$operand, HLO_Tensor:$source, @@ -1153,7 +1167,7 @@ def HLO_SetDimensionSizeOp: HLO_Op<"set_dimension_size", [NoSideEffect]>, let results = (outs HLO_Tensor); } -def HLO_SortOp : HLO_Op<"sort", [NoSideEffect]>, BASE_HLO_SortOp { +def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects]>, BASE_HLO_SortOp { let arguments = (ins Variadic:$operands, DefaultValuedAttr:$dimension, @@ -1205,7 +1219,7 @@ def HLO_PadOp: HLO_Op<"pad", let hasCustomHLOConverter = 1; } -def HLO_TraceOp: HLO_Op<"trace", [NoSideEffect]>, BASE_HLO_TraceOp { +def HLO_TraceOp: HLO_Op<"trace", []>, BASE_HLO_TraceOp { let arguments = (ins HLO_Tensor:$operand, StrAttr:$tag @@ -1239,7 +1253,7 @@ def HLO_TriangularSolveOp: HLO_Op<"triangular_solve", } def HLO_ReduceWindowOp: HLO_Op<"reduce_window", [ - NoSideEffect, + RecursiveSideEffects, SingleBlockImplicitTerminator<"ReturnOp"> ]>, BASE_HLO_ReduceWindowOp { @@ -1270,7 +1284,7 @@ def HLO_ReduceWindowOp: HLO_Op<"reduce_window", [ // TODO(hinsu): Implement custom printer and parser. } -def HLO_ReturnOp : HLO_Op<"return", [Terminator]> { +def HLO_ReturnOp : HLO_Op<"return", [NoSideEffect, Terminator]> { let summary = [{ The `hlo.return` operation terminates a region and returns values. 
}]; diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td index c087ffd1f40..bad1bf16ec3 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td @@ -62,9 +62,11 @@ def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>; def HLO_TensorOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Tuple]>; +def HLO_DimensionValue : AnyTypeOf<[Index, HLO_Pred, HLO_Int]>; + // Dynamic representation of a shape vector as a tensor. def HLO_DimensionTensor : ShapedContainerType< - [Index, HLO_Pred, HLO_Int], + [HLO_DimensionValue], And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>, "a 1D tensor of dimensions">; @@ -150,15 +152,6 @@ class BASE_HLO_ClzOp { }]; } -class BASE_HLO_ComplexOp { - string summary = "Complex operator"; - - string description = [{ - Performs element-wise conversion of a pair of real and imaginary values to - a complex value. - }]; -} - class BASE_HLO_ConvertOp { string summary = "Convert operator"; @@ -400,6 +393,15 @@ class BASE_HLO_AddOp { }]; } +class BASE_HLO_ComplexOp { + string summary = "Complex operator"; + + string description = [{ + Performs element-wise conversion of a pair of real and imaginary values to + a complex value. + }]; +} + class BASE_HLO_DivOp { string summary = "Division operator"; @@ -553,6 +555,29 @@ class BASE_HLO_XorOp { }]; } +//===----------------------------------------------------------------------===// +// XLA control flow related op definitions. +//===----------------------------------------------------------------------===// + +class BASE_HLO_CaseOp { + string summary = "Switch-Case operator"; + + string description = [{ + Returns the result of executing `branches[index]`. If + `index` is < 0 or >= N, then `branches[N-1] is executed as + the default branch. + + Each branch `branches[b]` must take in a single argument of same type as + `branch_operands[b]` and will be invoked with `branch_operands[b]`. The type + of the returned value of each branch must be the same. + + Note that only one of the branches will be executed depending on the value + of index. + See https://www.tensorflow.org/xla/operation_semantics#conditional. + }]; + +} + //===----------------------------------------------------------------------===// // XLA parallelism related op definitions. //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_utils.h b/tensorflow/compiler/mlir/xla/ir/hlo_utils.h index 079169e9c5c..03e41f6432c 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_utils.h +++ b/tensorflow/compiler/mlir/xla/ir/hlo_utils.h @@ -35,22 +35,33 @@ mlir::DenseIntElementsAttr getBroadcastDimensionsAttr(mlir::Builder* b, mlir::Value y, bool allow_empty = true); -/// Get a constant splat for the given value type. +// Get a constant splat for the given value of type. Requires value to be of +// type static shaped RankedTensorType. 
+template +static ElementsAttr getSplat(Builder* b, RankedTensorType ty, T constant) { + Type element_ty = getElementTypeOrSelf(ty); + + if (element_ty.isSignlessInteger()) + return DenseElementsAttr::get(ty, b->getIntegerAttr(element_ty, constant)); + + if (element_ty.isa()) + return DenseElementsAttr::get(ty, b->getFloatAttr(element_ty, constant)); + + if (auto complex_ty = element_ty.dyn_cast()) { + auto complex_element_ty = complex_ty.getElementType(); + if (complex_element_ty.isF32()) + return DenseElementsAttr::get(ty, + static_cast>(constant)); + if (complex_element_ty.isF64()) + return DenseElementsAttr::get( + ty, static_cast>(constant)); + } + llvm_unreachable("unhandled element type"); +} + template static ElementsAttr getSplat(Builder* b, Value val, T constant) { - auto valType = val.getType().cast(); - auto valElementType = getElementTypeOrSelf(val.getType()); - - // Handle integer elements. - Attribute elementAttr; - if (valElementType.isSignlessInteger()) - elementAttr = b->getIntegerAttr(valElementType, constant); - else if (valElementType.isa()) - elementAttr = b->getFloatAttr(valElementType, constant); - else - llvm_unreachable("unhandled element type"); - - return DenseElementsAttr::get(valType, elementAttr); + return getSplat(b, val.getType().cast(), constant); } // Returns DenseElementsAttr of rank zero with the given element type and the diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc index 680a73e49c5..6f9b39377af 100644 --- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc @@ -55,6 +55,36 @@ XlaLhloDialect::XlaLhloDialect(MLIRContext *context) >(); } +//===----------------------------------------------------------------------===// +// StaticMemRefCastOp +//===----------------------------------------------------------------------===// + +Value StaticMemRefCastOp::getViewSource() { return *getODSOperands(0).begin(); } + +static LogicalResult Verify(StaticMemRefCastOp op) { + if (!op.operand().getType().cast().hasStaticShape()) + return op.emitOpError("operand must have static shape"); + if (!op.getType().hasStaticShape()) + return op.emitOpError("result must have static shape"); + return success(); +} + +//===----------------------------------------------------------------------===// +// DynamicMemRefCastOp +//===----------------------------------------------------------------------===// + +Value DynamicMemRefCastOp::getViewSource() { + return *getODSOperands(0).begin(); +} + +static LogicalResult Verify(DynamicMemRefCastOp op) { + // Check if `sizes` and `strides` args are compatible with the result type. + if (op.sizes().size() != op.getType().getRank()) + return op.emitOpError( + "`sizes` args count must be equal to the rank of the output memref"); + return success(); +} + #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc.inc" diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.h b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.h index 190c5ff832d..3827e8a7a4e 100644 --- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.h +++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.h @@ -27,7 +27,8 @@ limitations under the License. 
#include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project -#include "mlir/Interfaces/SideEffects.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/ViewLikeInterface.h" // from @llvm-project namespace mlir { class OpBuilder; diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td index 3abd117f570..d9f3648bb09 100644 --- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td @@ -19,7 +19,8 @@ limitations under the License. #define LHLO_OPS include "mlir/IR/OpBase.td" -include "mlir/Interfaces/SideEffects.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/ViewLikeInterface.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td" def LHLO_Dialect : Dialect { @@ -92,10 +93,20 @@ def LHLO_CosOp: LHLO_UnaryElementwiseOp<"cosine">, BASE_HLO_CosOp; def LHLO_ExpOp: LHLO_UnaryElementwiseOp<"exponential">, BASE_HLO_ExpOp; +def LHLO_ImagOp: LHLO_Op<"imag", [SameOperandsShape]>, BASE_HLO_ImagOp { + let arguments = (ins Arg:$input, + Arg:$output); +} + def LHLO_LogOp: LHLO_UnaryElementwiseOp<"log">, BASE_HLO_LogOp; def LHLO_NegOp: LHLO_UnaryElementwiseOp<"negate">, BASE_HLO_NegOp; +def LHLO_RealOp: LHLO_Op<"real", [SameOperandsShape]>, BASE_HLO_RealOp { + let arguments = (ins Arg:$input, + Arg:$output); +} + def LHLO_RsqrtOp: LHLO_UnaryElementwiseOp<"rsqrt">, BASE_HLO_RsqrtOp; def LHLO_SqrtOp: LHLO_UnaryElementwiseOp<"sqrt">, BASE_HLO_SqrtOp; @@ -106,27 +117,6 @@ def LHLO_SinOp: LHLO_UnaryElementwiseOp<"sine">, BASE_HLO_SinOp; def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh">, BASE_HLO_TanhOp; -//===----------------------------------------------------------------------===// -// XLA complex unary elementwise op definitions. -//===----------------------------------------------------------------------===// -// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions - -def LHLO_ComplexOp: LHLO_Op<"complex", [SameOperandsShape]>, BASE_HLO_ComplexOp { - let arguments = (ins Arg:$lhs, - Arg:$rhs, - Arg:$output); -} - -def LHLO_ImagOp: LHLO_Op<"imag", [SameOperandsShape]>, BASE_HLO_ImagOp { - let arguments = (ins Arg:$input, - Arg:$output); -} - -def LHLO_RealOp: LHLO_Op<"real", [SameOperandsShape]>, BASE_HLO_RealOp { - let arguments = (ins Arg:$input, - Arg:$output); -} - //===----------------------------------------------------------------------===// // XLA binary elementwise op definitions. 
//===----------------------------------------------------------------------===// @@ -144,6 +134,12 @@ class LHLO_BinaryElementwiseOp traits> : def LHLO_AddOp : LHLO_BinaryElementwiseOp<"add", []>, BASE_HLO_AddOp; +def LHLO_ComplexOp: LHLO_Op<"complex", [SameOperandsShape]>, BASE_HLO_ComplexOp { + let arguments = (ins Arg:$lhs, + Arg:$rhs, + Arg:$output); +} + def LHLO_DivOp : LHLO_BinaryElementwiseOp<"divide", []>, BASE_HLO_DivOp; def LHLO_MaxOp : LHLO_BinaryElementwiseOp<"maximum", []>, BASE_HLO_MaxOp; @@ -201,6 +197,19 @@ def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [ let regions = (region SizedRegion<1>:$body); } +def LHLO_CaseOp: LHLO_Op<"case", [ + SingleBlockImplicitTerminator<"TerminatorOp"> + ]>, BASE_HLO_CaseOp { + + let arguments = (ins + Arg:$index, + Arg, "", [MemRead]>:$branch_operands, + Arg:$out + ); + + let regions = (region VariadicRegion>:$branches); +} + //===----------------------------------------------------------------------===// // XLA tuple op definitions. //===----------------------------------------------------------------------===// @@ -254,6 +263,96 @@ def HLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> { ); } +//===----------------------------------------------------------------------===// +// StaticMemRefCastOp +//===----------------------------------------------------------------------===// + +def HLO_StaticMemRefCastOp: Op]> { + let summary = "static memref cast operation"; + let description = [{ + Allows to modify the offset, sizes and strides of a statically shaped memref. + + Example: + ```mlir + %buf_transformed = + xla_lhlo.static_memref_cast %buf + : memref<1x5xf32> -> memref<5xf32, offset: 2, strides: [1]> + + // The result of the op is a rank-1 memref with `[5]` shape, stride 1 and + // offset 2. + ``` + }]; + + let arguments = (ins Arg:$operand); + let results = (outs Res:$result); + + let builders = [OpBuilder< + "OpBuilder &builder, OperationState &result, MemRefType resultType, " # + "Value operand", [{ + result.addOperands(operand); + result.types.push_back(resultType); + }]>]; + + let extraClassDeclaration = [{ + MemRefType getType() { return getResult().getType().cast(); } + }]; + + let verifier = [{ return Verify(*this); }]; + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) `->` type($result) + }]; +} + +//===----------------------------------------------------------------------===// +// DynamicMemRefCastOp +//===----------------------------------------------------------------------===// + +def HLO_DynamicMemRefCastOp: Op]> { + let summary = "dynamic memref cast operation"; + let description = [{ + Change sizes and strides of a memref using the values computed in runtime. + + Example: + ```mlir + %buf_transformed = + xla_lhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%step_X, %step_Y] + : memref -> memref + // The result of the op is a type-erased memref with `[%size_X, %size_Y]` + // shape and `[%step_X, %step_Y]` strides. The offset will be inherited + // from the input. 
+ ``` + }]; + + let arguments = (ins + Arg:$operand, + Variadic:$sizes, + Variadic:$strides + ); + let results = (outs Res:$result); + + let builders = [OpBuilder< + "OpBuilder &builder, OperationState &result, MemRefType resultType, " # + "Value operand, ValueRange sizes, ValueRange strides", [{ + result.addOperands(operand); + result.addOperands(sizes); + result.addOperands(strides); + result.types.push_back(resultType); + }]>]; + + let extraClassDeclaration = [{ + MemRefType getType() { return getResult().getType().cast(); } + }]; + + let verifier = [{ return Verify(*this); }]; + let assemblyFormat = [{ + $operand `(` $sizes `)` `[` $strides `]` attr-dict `:` type($operand) `->` + type($result) + }]; +} + //===----------------------------------------------------------------------===// // XLA Other op definitions. //===----------------------------------------------------------------------===// @@ -436,6 +535,10 @@ def TerminatorOp : let description = [{ Terminator operation for the LHLO dialect. }]; + let builders = [OpBuilder< + "OpBuilder &b, OperationState &result, ValueRange operands", + [{ build(b, result, llvm::None, operands, llvm::None); }] + >]; } #endif // LHLO_OPS diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc index 99d1da74fc5..774caab77fb 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc @@ -56,6 +56,20 @@ static mlir::DenseIntElementsAttr GetI64ElementsAttr( return mlir::DenseIntElementsAttr::get(ty, mlir_values); } +static mlir::DenseIntElementsAttr ConvertPadding( + absl::Span> padding, + mlir::Builder* builder) { + llvm::SmallVector elements; + elements.reserve(padding.size() * 2); + for (const auto& vals : padding) { + elements.push_back(vals.first); + elements.push_back(vals.second); + } + auto ty = mlir::RankedTensorType::get( + {static_cast(padding.size()), 2}, builder->getIntegerType(64)); + return mlir::DenseIntElementsAttr::get(ty, elements); +} + MlirHloBuilder::~MlirHloBuilder() = default; StatusOr MlirHloBuilder::MakeXlaOp(mlir::Value val) { @@ -79,6 +93,31 @@ XlaOp MlirHloBuilder::ConstantLiteral(const LiteralSlice& literal) { }); } +StatusOr MlirHloBuilder::ConvGeneralDilatedInternal( + const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config) { + TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( + shape, builder_)); + mlir::ArrayAttr config_attr; + if (precision_config) + config_attr = ConvertPrecisionConfig(precision_config, &builder_); + auto op = builder_.create( + loc_, ty, GetValue(lhs), GetValue(rhs), + GetI64ElementsAttr(window_strides, &builder_), + ConvertPadding(padding, &builder_), + GetI64ElementsAttr(lhs_dilation, &builder_), + GetI64ElementsAttr(rhs_dilation, &builder_), + ConvertConvDimensionNumbers(dimension_numbers, &builder_), + builder_.getI64IntegerAttr(feature_group_count), + builder_.getI64IntegerAttr(batch_group_count), config_attr); + return MakeXlaOp(op); +} + StatusOr MlirHloBuilder::TransposeInternal( const Shape& shape, XlaOp operand, absl::Span permutation) { TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( @@ -170,7 +209,6 @@ StatusOr MlirHloBuilder::Compare(const Shape& shape, XlaOp lhs, shape, builder_)); auto op = 
builder_.create<mlir::xla_hlo::CompareOp>( loc_, ty, GetValue(lhs), GetValue(rhs), - /*broadcast_dimensions=*/mlir::DenseIntElementsAttr(), builder_.getStringAttr(ComparisonDirectionToString(direction))); return MakeXlaOp(op.getResult()); } @@ -243,6 +281,28 @@ StatusOr<XlaOp> MlirHloBuilder::SliceInternal( GetI64ElementsAttr(strides, &builder_))); } +StatusOr<XlaOp> MlirHloBuilder::DynamicSliceInternal( + const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices, + absl::Span<const int64> slice_sizes) { + TF_ASSIGN_OR_RETURN( + mlir::Type result_ty, + ConvertShapeToType<mlir::RankedTensorType>(shape, builder_)); + return MakeXlaOp(builder_.create<mlir::xla_hlo::DynamicSliceOp>( + loc_, result_ty, GetValue(operand), GetValues(start_indices), + GetI64ElementsAttr(slice_sizes, &builder_))); +} + +StatusOr<XlaOp> MlirHloBuilder::DynamicUpdateSliceInternal( + const Shape& shape, XlaOp operand, XlaOp update, + absl::Span<const XlaOp> start_indices) { + TF_ASSIGN_OR_RETURN( + mlir::Type result_ty, + ConvertShapeToType<mlir::RankedTensorType>(shape, builder_)); + return MakeXlaOp(builder_.create<mlir::xla_hlo::DynamicUpdateSliceOp>( + loc_, result_ty, GetValue(operand), GetValue(update), + GetValues(start_indices))); +} + StatusOr<XlaOp> MlirHloBuilder::PadInternal( const Shape& shape, XlaOp operand, XlaOp padding_value, const PaddingConfig& padding_config) { diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h index dbcb6856971..fc5baaee44d 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h @@ -110,6 +110,16 @@ class MlirHloBuilder : public XlaBuilder { private: XlaOp ConstantLiteral(const LiteralSlice& literal) override; + StatusOr<XlaOp> ConvGeneralDilatedInternal( + const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window, + absl::Span<const int64> window_strides, + absl::Span<const std::pair<int64, int64>> padding, + absl::Span<const int64> lhs_dilation, + absl::Span<const int64> rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config) override; + StatusOr<XlaOp> TransposeInternal( const Shape& shape, XlaOp operand, absl::Span<const int64> permutation) override; @@ -165,6 +175,14 @@ class MlirHloBuilder : public XlaBuilder { absl::Span<const int64> limit_indices, absl::Span<const int64> strides) override; + StatusOr<XlaOp> DynamicSliceInternal( + const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices, + absl::Span<const int64> slice_sizes) override; + + StatusOr<XlaOp> DynamicUpdateSliceInternal( + const Shape& shape, XlaOp operand, XlaOp update, + absl::Span<const XlaOp> start_indices) override; + StatusOr<XlaOp> PadInternal(const Shape& shape, XlaOp operand, XlaOp padding_value, const PaddingConfig& padding_config) override; diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index a1fb6b559e3..3c670ef0c6e 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -602,13 +602,12 @@ LogicalResult ExportXlaOp(BroadcastInDimOp op, OpLoweringContext ctx) { return success(); } -LogicalResult ExportXlaOp(ScalarsToDimensionTensorOp op, - OpLoweringContext ctx) { +LogicalResult ExportXlaOp(DynamicBroadcastInDimOp op, OpLoweringContext ctx) { // This op has no expression in the legacy export format. return failure(); } -LogicalResult ExportXlaOp(DynamicBroadcastInDimOp op, OpLoweringContext ctx) { +LogicalResult ExportXlaOp(DynamicIotaOp op, OpLoweringContext ctx) { // This op has no expression in the legacy export format.
return failure(); } @@ -618,7 +617,7 @@ LogicalResult ExportXlaOp(DynamicReshapeOp op, OpLoweringContext ctx) { return failure(); } -LogicalResult ExportXlaOp(ConditionalOp op, OpLoweringContext ctx) { +LogicalResult ExportXlaOp(IfOp op, OpLoweringContext ctx) { xla::XlaComputation true_branch; xla::XlaComputation false_branch; auto& value_map = *ctx.values; @@ -636,6 +635,33 @@ LogicalResult ExportXlaOp(ConditionalOp op, OpLoweringContext ctx) { return success(); } +LogicalResult ExportXlaOp(CaseOp op, OpLoweringContext ctx) { + llvm::DenseMap& value_map = *ctx.values; + OperandRange operands = op.branch_operands(); + MutableArrayRef branches = op.branches(); + llvm::SmallVector branch_operands(branches.size()); + std::vector computations(branches.size()); + std::vector computations_p(branches.size()); + + for (unsigned i = 0; i < branches.size(); ++i) { + branch_operands[i] = value_map[operands[i]]; + computations_p[i] = &computations[i]; + if (failed(ctx.converter->LowerRegionAsComputation(&branches[i], + computations_p[i]))) + return failure(); + } + xla::XlaOp result = + xla::Conditional(value_map[op.index()], computations_p, branch_operands); + if (op.getNumResults() == 1) { + value_map[op.getResult(0)] = result; + } else { + for (auto item : llvm::enumerate(op.getResults())) { + value_map[item.value()] = xla::GetTupleElement(result, item.index()); + } + } + return success(); +} + LogicalResult ExportXlaOp(ConstOp op, OpLoweringContext ctx) { return failure(); } @@ -933,6 +959,8 @@ StatusOr CreateLiteralFromAttr(ElementsAttr attr) { ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U16, uint16) ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U32, uint32) ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U64, uint64) + ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::C64, std::complex) + ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::C128, std::complex) case xla::PrimitiveType::F16: { llvm::SmallVector values; values.reserve(attr.getNumElements()); @@ -984,6 +1012,21 @@ LogicalResult ConvertToHloModule::Lower( return LowerFunctionCall(&call_op, builder, &value_map); } + if (auto op = dyn_cast(inst)) { + Value operand = op.getOperand(); + auto ty = operand.getType().dyn_cast(); + // If this was a cast from a static shaped tensors, then it is a noop for + // export to HLO and we can use the operand. + if (!ty || !ty.hasStaticShape()) { + inst->emitOpError() + << "requires static shaped operand for HLO translation"; + return failure(); + } + + value_map[op.getResult()] = value_map[operand]; + return success(); + } + // TODO(jpienaar): This doesn't support layouts yet. 
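
The `ExportXlaOp(CaseOp, ...)` hook added above lowers each branch region to an `xla::XlaComputation` and emits a single `xla::Conditional` keyed on the branch index; when the op has more than one result, the tuple result is unpacked with `xla::GetTupleElement`. A minimal sketch of the kind of `xla_hlo.case` this path consumes is shown below; the function name, value names, and element types are chosen only for illustration and are not taken from this change.

```mlir
// Illustrative two-branch case: one operand per branch, each region
// terminated by xla_hlo.return. The exporter lowers each region to an
// XlaComputation and selects between them with xla::Conditional.
func @case_sketch(%index: tensor<i32>, %operand: tensor<f32>) -> tensor<f32> {
  %0 = "xla_hlo.case"(%index, %operand, %operand) ( {
  ^bb0(%arg: tensor<f32>):
    %1 = "xla_hlo.exponential"(%arg) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
  },  {
  ^bb0(%arg: tensor<f32>):
    %1 = "xla_hlo.floor"(%arg) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
  }) : (tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}
```

A single-result case like this maps directly to the conditional's result; only the multi-result form goes through the tuple unpacking in the exporter code above.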
if (matchPattern(inst, m_Constant(&const_attr))) { auto literal_or = CreateLiteralFromAttr(const_attr); diff --git a/tensorflow/compiler/mlir/xla/tests/buffer-assignment.mlir b/tensorflow/compiler/mlir/xla/tests/buffer-assignment.mlir deleted file mode 100644 index ad007d0eb50..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/buffer-assignment.mlir +++ /dev/null @@ -1,260 +0,0 @@ -// RUN: tf-opt -test-buffer-assignment -allow-unregistered-dialect -split-input-file %s | FileCheck %s -dump-input-on-failure - -// CHECK-LABEL: func @func_signature_conversion -func @func_signature_conversion(%arg0: tensor<4x8xf32>) { - return -} -// CHECK: ({{.*}}: memref<4x8xf32>) { - -// ----- - -// CHECK-LABEL: func @non_void_to_void_return_op_converter -func @non_void_to_void_return_op_converter(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { - return %arg0 : tensor<4x8xf32> -} -// CHECK: (%[[ARG0:.*]]: [[TYPE:.*]]<[[RANK:.*]]>, %[[RESULT:.*]]: [[TYPE]]<[[RANK]]>) { -// CHECK-NEXT: "buffer_assignment_test.copy"(%[[ARG0]], %[[RESULT]]) -// CHECK-NEXT: return - -// ----- - -// CHECK-LABEL: func @func_and_block_signature_conversion -func @func_and_block_signature_conversion(%arg0 : tensor<2xf32>, %cond : i1, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32>{ - cond_br %cond, ^bb1, ^bb2 - ^bb1: - br ^exit(%arg0 : tensor<2xf32>) - ^bb2: - br ^exit(%arg0 : tensor<2xf32>) - ^exit(%arg2: tensor<2xf32>): - return %arg1 : tensor<4x4xf32> -} -// CHECK: (%[[ARG0:.*]]: [[ARG0_TYPE:.*]], %[[COND:.*]]: i1, %[[ARG1:.*]]: [[ARG1_TYPE:.*]], %[[RESULT:.*]]: [[RESULT_TYPE:.*]]) { -// CHECK: br ^[[EXIT_BLOCK:.*]](%[[ARG0]] : [[ARG0_TYPE]]) -// CHECK: br ^[[EXIT_BLOCK]](%[[ARG0]] : [[ARG0_TYPE]]) -// CHECK: ^[[EXIT_BLOCK]](%{{.*}}: [[ARG0_TYPE]]) -// CHECK-NEXT: "buffer_assignment_test.copy"(%[[ARG1]], %[[RESULT]]) -// CHECK-NEXT: return - -// ----- - -// CHECK-LABEL: func @condBranch -func @condBranch(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ - cond_br %cond, ^bb1, ^bb2 - ^bb1: - br ^exit(%arg0 : tensor<2xf32>) - ^bb2: - %1 = "buffer_assignment_test.unary"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - br ^exit(%1 : tensor<2xf32>) - ^exit(%arg1: tensor<2xf32>): - return %arg1 : tensor<2xf32> - -} -// CHECK-NEXT: %[[ALLOC:.*]] = alloc() -// CHECK-NEXT: cond_br -// CHECK: "buffer_assignment_test.copy -// CHECK-NEXT: dealloc %[[ALLOC]] -// CHECK-NEXT: return - -// ----- - -// CHECK-LABEL: func @emptyUsesValue -func @emptyUsesValue(%arg0: memref<4xf32>) { - %0 = alloc() : memref<4xf32> - return -} -// CHECK-NEXT: %[[ALLOC:.*]] = alloc() -// CHECK-NEXT: dealloc %[[ALLOC]] -// CHECK-NEXT: return - -// ----- - -// CHECK-LABEL: func @criticalEdge -func @criticalEdge(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ - cond_br %cond, ^bb1, ^exit(%arg0 : tensor<2xf32>) - ^bb1: - %0 = "buffer_assignment_test.unary"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - br ^exit(%0 : tensor<2xf32>) - ^exit(%arg1: tensor<2xf32>): - return %arg1 : tensor<2xf32> -} -// CHECK-NEXT: %[[ALLOC:.*]] = alloc() -// CHECK-NEXT: cond_br -// CHECK: "buffer_assignment_test.copy -// CHECK-NEXT: dealloc %[[ALLOC]] -// CHECK-NEXT: return - -// ----- - -// CHECK-LABEL: func @invCriticalEdge -func @invCriticalEdge(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ - %0 = "buffer_assignment_test.unary"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - cond_br %cond, ^bb1, ^exit(%arg0 : tensor<2xf32>) - ^bb1: - br ^exit(%0 : tensor<2xf32>) - ^exit(%arg1: tensor<2xf32>): - return %arg1 : tensor<2xf32> -} -// CHECK-NEXT: %[[ALLOC:.*]] = alloc() -// CHECK-NEXT: 
"buffer_assignment_test.unary_lowered" -// CHECK: "buffer_assignment_test.copy -// CHECK-NEXT: dealloc %[[ALLOC]] -// CHECK-NEXT: return - -// ----- - -// CHECK-LABEL: func @ifElse -func @ifElse(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ - %0 = "buffer_assignment_test.unary"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - cond_br %cond, ^bb1(%arg0, %0: tensor<2xf32>, tensor<2xf32>), - ^bb2(%0, %arg0: tensor<2xf32>, tensor<2xf32>) - ^bb1(%arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>): - br ^exit(%arg1, %arg2 : tensor<2xf32>, tensor<2xf32>) - ^bb2(%arg3 : tensor<2xf32>, %arg4 : tensor<2xf32>): - br ^exit(%arg3, %arg4 : tensor<2xf32>, tensor<2xf32>) - ^exit(%arg5 : tensor<2xf32>, %arg6 : tensor<2xf32>): - %1 = "buffer_assignment_test.unary"(%arg5) : (tensor<2xf32>) -> tensor<2xf32> - return %1 : tensor<2xf32> -} -// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc() -// CHECK-NEXT: "buffer_assignment_test.unary_lowered" -// CHECK: %[[SECOND_ALLOC:.*]] = alloc() -// CHECK-NEXT: "buffer_assignment_test.unary_lowered" -// CHECK-NEXT: dealloc %[[FIRST_ALLOC]] -// CHECK-NEXT: "buffer_assignment_test.copy -// CHECK-NEXT: dealloc %[[SECOND_ALLOC]] -// CHECK-NEXT: return - -// ----- - -// CHECK-LABEL: func @ifElseNoUsers -func @ifElseNoUsers(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ - %0 = "buffer_assignment_test.unary"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - cond_br %cond, ^bb1(%arg0, %0: tensor<2xf32>, tensor<2xf32>), - ^bb2(%0, %arg0: tensor<2xf32>, tensor<2xf32>) - ^bb1(%arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>): - br ^exit(%arg1, %arg2 : tensor<2xf32>, tensor<2xf32>) - ^bb2(%arg3 : tensor<2xf32>, %arg4 : tensor<2xf32>): - br ^exit(%arg3, %arg4 : tensor<2xf32>, tensor<2xf32>) - ^exit(%arg5 : tensor<2xf32>, %arg6 : tensor<2xf32>): - return %arg0 : tensor<2xf32> -} -// CHECK-NEXT: %[[ALLOC:.*]] = alloc() -// CHECK-NEXT: "buffer_assignment_test.unary_lowered" -// CHECK: "buffer_assignment_test.copy -// CHECK-NEXT: dealloc %[[ALLOC]] -// CHECK-NEXT: return - -// ----- - -// CHECK-LABEL: func @ifElseNested -func @ifElseNested(%cond : i1, %arg0 : tensor<2xf32>) -> tensor<2xf32>{ - %0 = "buffer_assignment_test.unary"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - cond_br %cond, ^bb1(%arg0, %0: tensor<2xf32>, tensor<2xf32>), - ^bb2(%0, %arg0: tensor<2xf32>, tensor<2xf32>) - ^bb1(%arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>): - br ^exit(%arg1, %arg2 : tensor<2xf32>, tensor<2xf32>) - ^bb2(%arg3 : tensor<2xf32>, %arg4 : tensor<2xf32>): - cond_br %cond, ^bb3(%arg3 : tensor<2xf32>), ^bb4(%arg4 : tensor<2xf32>) - ^bb3(%arg7 : tensor<2xf32>): - br ^exit(%arg7, %arg3 : tensor<2xf32>, tensor<2xf32>) - ^bb4(%arg8 : tensor<2xf32>): - br ^exit(%arg3, %arg8 : tensor<2xf32>, tensor<2xf32>) - ^exit(%arg5 : tensor<2xf32>, %arg6 : tensor<2xf32>): - %1 = "buffer_assignment_test.unary"(%arg5) : (tensor<2xf32>) -> tensor<2xf32> - return %1 : tensor<2xf32> -} -// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc() -// CHECK-NEXT: "buffer_assignment_test.unary_lowered" -// CHECK: %[[SECOND_ALLOC:.*]] = alloc() -// CHECK-NEXT: "buffer_assignment_test.unary_lowered" -// CHECK-NEXT: dealloc %[[FIRST_ALLOC]] -// CHECK-NEXT: "buffer_assignment_test.copy -// CHECK-NEXT: dealloc %[[SECOND_ALLOC]] -// CHECK-NEXT: return - -// ----- - -// CHECK-LABEL: func @redundantOperations -func @redundantOperations(%arg0: tensor<4xf32>) { - %1 = "buffer_assignment_test.unary"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> - %2 = "buffer_assignment_test.unary"(%1) : (tensor<4xf32>) -> tensor<4xf32> - return -} -// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = 
alloc() -// CHECK-NEXT: "buffer_assignment_test.unary_lowered" -// CHECK-NEXT: %[[SECOND_ALLOC:.*]] = alloc() -// CHECK-NEXT: "buffer_assignment_test.unary_lowered" -// CHECK-NEXT: dealloc -// CHECK-NEXT: dealloc -// CHECK-NEXT: return - -// ----- - -// CHECK-LABEL: func @moving_alloc_and_inserting_missing_dealloc -func @moving_alloc_and_inserting_missing_dealloc(%cond : i1, %arg0 : memref<2xf32>, %arg1: memref<2xf32>){ - cond_br %cond, ^bb1, ^bb2 - ^bb1: - %0 = alloc() : memref<2xf32> - "buffer_assignment_test.unary_lowered"(%arg0, %0) : (memref<2xf32>, memref<2xf32>) -> () - br ^exit(%0 : memref<2xf32>) - ^bb2: - - %1 = alloc() : memref<2xf32> - "buffer_assignment_test.unary_lowered"(%arg0, %1) : (memref<2xf32>, memref<2xf32>) -> () - br ^exit(%1 : memref<2xf32>) - ^exit(%arg2: memref<2xf32>): - "bufer_assignment_test.copy"(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>) -> () - return -} -// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc() -// CHECK-NEXT: %[[SECOND_ALLOC:.*]] = alloc() -// CHECK: "bufer_assignment_test.copy" -// CHECK-NEXT: dealloc -// CHECK-NEXT: dealloc -// CHECK-NEXT: return - -// ----- - -// CHECK-LABEL: func @moving_invalid_dealloc_op_complex -func @moving_invalid_dealloc_op_complex(%cond : i1, %arg0 : memref<2xf32>, %arg1: memref<2xf32>){ - cond_br %cond, ^bb1, ^bb2 - ^bb1: - br ^exit(%arg0 : memref<2xf32>) - ^bb2: - %1 = alloc() : memref<2xf32> - "buffer_assignment_test.unary_lowered"(%arg0, %1) : (memref<2xf32>, memref<2xf32>) -> () - dealloc %1 : memref<2xf32> - br ^exit(%1 : memref<2xf32>) - ^exit(%arg2: memref<2xf32>): - "bufer_assignment_test.copy"(%arg2, %arg1) : (memref<2xf32>, memref<2xf32>) -> () - return -} -// CHECK-NEXT: %[[ALLOC:.*]] = alloc() -// CHECK: bufer_assignment_test.copy -// CHECK-NEXT: dealloc -// CHECK-NEXT: return - -// ----- - -// CHECK-LABEL: func @inserting_missing_dealloc_simple -func @inserting_missing_dealloc_simple(%arg0 : memref<2xf32>, %arg1: memref<2xf32>){ - %0 = alloc() : memref<2xf32> - "buffer_assignment_test.unary_lowered"(%arg0, %0) : (memref<2xf32>, memref<2xf32>) -> () - "bufer_assignment_test.copy"(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> () - return -} -// CHECK: bufer_assignment_test.copy -// CHECK-NEXT: dealloc - -// ----- - -// CHECK-LABEL: func @moving_invalid_dealloc_op -func @moving_invalid_dealloc_op(%arg0 : memref<2xf32>, %arg1: memref<2xf32>){ - %0 = alloc() : memref<2xf32> - "buffer_assignment_test.unary_lowered"(%arg0, %0) : (memref<2xf32>, memref<2xf32>) -> () - dealloc %0 : memref<2xf32> - "bufer_assignment_test.copy"(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> () - return -} -// CHECK: bufer_assignment_test.copy -// CHECK-NEXT: dealloc \ No newline at end of file diff --git a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir index 30255586002..6b9ed3e463a 100644 --- a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir @@ -45,6 +45,60 @@ func @multiply_scalar_fold() -> tensor<4xi64> { return %2 : tensor<4xi64> } +// CHECK-LABEL: divide_scalar_fold +func @divide_scalar_fold() -> tensor<4xi64> { + %0 = xla_hlo.constant dense<7> : tensor<4xi64> + %1 = xla_hlo.constant dense<5> : tensor<4xi64> + // CHECK: xla_hlo.constant dense<1> + %2 = "xla_hlo.divide"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + return %2 : tensor<4xi64> +} + +// CHECK-LABEL: divide_fold_float +func @divide_fold_float() -> tensor<4xf64> { + %0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : 
tensor<4xf64> + %1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64> + // CHECK: xla_hlo.constant dense<[1.000000e+00, 2.200000e+01, 2.500000e+00, 2.500000e-01]> + %2 = "xla_hlo.divide"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>) + return %2 : tensor<4xf64> +} + +// CHECK-LABEL: max_scalar_fold +func @max_scalar_fold() -> tensor<4xi64> { + %0 = xla_hlo.constant dense<7> : tensor<4xi64> + %1 = xla_hlo.constant dense<5> : tensor<4xi64> + // CHECK: xla_hlo.constant dense<7> + %2 = "xla_hlo.maximum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + return %2 : tensor<4xi64> +} + +// CHECK-LABEL: max_fold_float +func @max_fold_float() -> tensor<4xf64> { + %0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64> + %1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64> + // CHECK: xla_hlo.constant dense<[5.000000e+00, 6.600000e+01, 5.000000e+00, 4.000000e+00]> + %2 = "xla_hlo.maximum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>) + return %2 : tensor<4xf64> +} + +// CHECK-LABEL: min_scalar_fold +func @min_scalar_fold() -> tensor<4xi64> { + %0 = xla_hlo.constant dense<7> : tensor<4xi64> + %1 = xla_hlo.constant dense<-5> : tensor<4xi64> + // CHECK: xla_hlo.constant dense<-5> + %2 = "xla_hlo.minimum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + return %2 : tensor<4xi64> +} + +// CHECK-LABEL: min_fold_float +func @min_fold_float() -> tensor<4xf64> { + %0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64> + %1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64> + // CHECK: xla_hlo.constant dense<[5.000000e+00, 3.000000e+00, 2.000000e+00, 1.000000e+00]> + %2 = "xla_hlo.minimum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>) + return %2 : tensor<4xf64> +} + // CHECK-LABEL: concatenate_noop func @concatenate_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> { // CHECK-SAME: [[ARG:%.+]]: tensor<4xi32> @@ -281,6 +335,14 @@ func @complex_collapse_fold(%arg0: tensor<4xcomplex>) -> tensor<4xcomplex> } +// CHECK-LABEL: @dynamic_iota_is_static +func @dynamic_iota_is_static(%arg0 : tensor<1xindex>) -> tensor<4xi32> { + // CHECK: [[RESULT:%.*]] = "xla_hlo.iota" + // CHECK: return [[RESULT]] + %0 = "xla_hlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<1xindex>) -> tensor<4xi32> + return %0 : tensor<4xi32> +} + // CHECK-LABEL: @iota_not_lowered_to_constant func @iota_not_lowered_to_constant() -> tensor<4xi32> { // CHECK: [[RESULT:%.*]] = "xla_hlo.iota" @@ -297,16 +359,6 @@ func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> { return %0 : tensor<2x2xf32> } -// CHECK-LABEL: @extract_scalars_to_tensor -// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32 -func @extract_scalars_to_tensor(%arg0: i32, %arg1: i32) -> i32 { - %0 = "xla_hlo.scalars_to_dimension_tensor"(%arg0, %arg1) : (i32, i32) -> tensor<2xi32> - %1 = constant 0 : index - %2 = extract_element %0[%1] : tensor<2xi32> - // CHECK: return %[[ARG0]] - return %2 : i32 -} - // CHECK-LABEL: func @fold_copy // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func @fold_copy(%arg : tensor<1x4xf32>) -> tensor<1x4xf32> { @@ -387,8 +439,8 @@ func @dynamic_reshape_not_actually_dynamic(%arg0: tensor<4xf32>, %shape: tensor< return %0 : tensor<4x1xf32> } -// CHECK-LABEL: do_not_dce_while -func @do_not_dce_while(%arg0: tensor) -> tensor { +// CHECK-LABEL: do_not_dce_while_with_outfeed +func @do_not_dce_while_with_outfeed(%arg0: tensor) -> tensor { // CHECK: xla_hlo.while %0 = "xla_hlo.while"(%arg0) ( { ^bb0(%arg1: tensor): @@ -404,3 +456,19 @@ 
func @do_not_dce_while(%arg0: tensor) -> tensor { return %arg0 : tensor } + +// CHECK-LABEL: dce_while_without_side_effect +func @dce_while_without_side_effect(%arg0: tensor) -> tensor { + // CHECK-NOT: xla_hlo.while + %0 = "xla_hlo.while"(%arg0) ( { + ^bb0(%arg1: tensor): + %1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg1: tensor): + %1 = "xla_hlo.create_token"() : () -> !xla_hlo.token + "xla_hlo.return"(%arg1) : (tensor) -> () + }) : (tensor) -> tensor + + return %arg0 : tensor +} diff --git a/tensorflow/compiler/mlir/xla/tests/chlo_infer_shape_type_methods.mlir b/tensorflow/compiler/mlir/xla/tests/chlo_infer_shape_type_methods.mlir index ce0243e416c..d67a7d09f7c 100644 --- a/tensorflow/compiler/mlir/xla/tests/chlo_infer_shape_type_methods.mlir +++ b/tensorflow/compiler/mlir/xla/tests/chlo_infer_shape_type_methods.mlir @@ -6,8 +6,8 @@ // CHECK-SAME: %[[ARG0:.+]]: tensor, // CHECK-SAME: %[[ARG1:.+]]: tensor func @broadcast_add(%arg0: tensor, %arg1: tensor) -> tensor<1xindex> { - // CHECK-DAG: %[[ARG0_S:.+]] = "shape.shape_of"(%[[ARG0]]) - // CHECK-DAG: %[[ARG1_S:.+]] = "shape.shape_of"(%[[ARG1]]) + // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] + // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] // CHECK-DAG: %[[BCAST_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) // CHECK: %[[EXTENTS:.+]] = "shape.to_extent_tensor"(%[[BCAST_S]]) // CHECK: return %[[EXTENTS]] diff --git a/tensorflow/compiler/mlir/xla/tests/chlo_legalize_to_hlo_broadcasts.mlir b/tensorflow/compiler/mlir/xla/tests/chlo_legalize_to_hlo_broadcasts.mlir index 2bc1e0c6852..7194f7034b5 100644 --- a/tensorflow/compiler/mlir/xla/tests/chlo_legalize_to_hlo_broadcasts.mlir +++ b/tensorflow/compiler/mlir/xla/tests/chlo_legalize_to_hlo_broadcasts.mlir @@ -14,8 +14,8 @@ func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor< // CHECK-SAME: %[[ARG0:.+]]: tensor // CHECK-SAME: %[[ARG1:.+]]: tensor func @dynamicBroadcast(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-DAG: %[[ARG0_S:.+]] = "shape.shape_of"(%[[ARG0]]) - // CHECK-DAG: %[[ARG1_S:.+]] = "shape.shape_of"(%[[ARG1]]) + // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] + // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] // CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) // CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_S]]) // CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} @@ -31,8 +31,8 @@ func @dynamicBroadcast(%arg0: tensor, %arg1: tensor) -> tensor // CHECK-SAME: %[[ARG1:.+]]: tensor func @dynamicBroadcastComplex(%arg0: tensor, %arg1: tensor) -> tensor> { - // CHECK-DAG: %[[ARG0_S:.+]] = "shape.shape_of"(%[[ARG0]]) - // CHECK-DAG: %[[ARG1_S:.+]] = "shape.shape_of"(%[[ARG1]]) + // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] + // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] // CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) // CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_S]]) // CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor @@ -48,8 +48,8 @@ func @dynamicBroadcastComplex(%arg0: tensor, %arg1: tensor) -> t // CHECK-SAME: %[[ARG0:.+]]: tensor // CHECK-SAME: %[[ARG1:.+]]: 
tensor func @dynamicBroadcastCompare(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-DAG: %[[ARG0_S:.+]] = "shape.shape_of"(%[[ARG0]]) - // CHECK-DAG: %[[ARG1_S:.+]] = "shape.shape_of"(%[[ARG1]]) + // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] + // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] // CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) // CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_S]]) // CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir index 262533bbf08..f4b9fa206f2 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir @@ -1,4 +1,4 @@ -// RUN: xla-opt -hlo-legalize-to-lhlo %s -o - | FileCheck %s --dump-input-on-failure +// RUN: xla-opt -hlo-legalize-to-lhlo -buffer-placement %s -o - | FileCheck %s --dump-input-on-failure // CHECK-LABEL: func @attrs func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { @@ -13,33 +13,42 @@ func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // ----- +func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> { + return %arg0 : tensor<4xf32> +} +// CHECK: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]]) +// CHECK-NEXT: "xla_lhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK-NEXT: "xla_lhlo.terminator"() : () -> () + +// ----- + // CHECK-LABEL: func @func_op_long func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>) - // CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() {temp = true} : memref<4xf32> - // CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() {temp = true} : memref<4xf32> - // CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() {temp = true} : memref<4xf32> - // CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() {temp = true} : memref<4xf32> - // CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() {temp = true} : memref<4xf32> %1 = xla_hlo.maximum %arg0, %arg1 : tensor<4xf32> - // CHECK-NEXT: "xla_lhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]]) %2 = xla_hlo.add %arg0, %1 : tensor<4xf32> - // CHECK-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]]) %3 = xla_hlo.minimum %arg0, %arg1 : tensor<4xf32> - // CHECK-NEXT: "xla_lhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]]) %4 = xla_hlo.subtract %arg1, %3 : tensor<4xf32> - // CHECK-NEXT: "xla_lhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]]) %5 = xla_hlo.multiply %2, %4 : tensor<4xf32> - // CHECK-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]]) - // CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32> - // CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32> - // CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32> - // CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32> - // CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> () - // CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32> return %5 : tensor<4xf32> - // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () } +// CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>) +// CHECK-NEXT: %[[MAX_RESULT:.*]] 
= alloc() : memref<4xf32> +// CHECK-NEXT: "xla_lhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]]) +// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32> +// CHECK-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]]) +// CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32> +// CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32> +// CHECK-NEXT: "xla_lhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]]) +// CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32> +// CHECK-NEXT: "xla_lhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]]) +// CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32> +// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32> +// CHECK-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]]) +// CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32> +// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32> +// CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> () +// CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32> +// CHECK-NEXT: "xla_lhlo.terminator"() : () -> () // ----- @@ -47,20 +56,20 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>, %summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}}, %[[RESULT:.*]]: {{.*}}) - // CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() {temp = true} : memref<2x2xf32> - // CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() {temp = true} : memref<2x2xf32> + // CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32> %tensor_summand_1 = tensor_load %summand_1 : memref<2x2xf32> %tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32> %sum = "xla_hlo.add"(%tensor_summand_1, %tensor_summand_2) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // CHECK-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]]) + // CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32> %tensor_multiplier = tensor_load %multiplier : memref<2x2xf32> %tensor_result = "xla_hlo.multiply"(%sum, %tensor_multiplier) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // CHECK-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]]) + // CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32> // CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) tensor_store %tensor_result, %result : memref<2x2xf32> - // CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32> // CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32> // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () "xla_lhlo.terminator"() : () -> () @@ -354,7 +363,7 @@ func @add_dyn(%lhs: tensor, %rhs: tensor) { // CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64 // CHECK: %[[DIM1:.*]] = dim %arg0, 1 : memref // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64 - // CHECK: %[[SHAPE:.*]] = "xla_hlo.scalars_to_dimension_tensor"(%[[IC0]], %[[IC1]]) : (i64, i64) -> tensor<2xi64> + // CHECK: %[[SHAPE:.*]] = tensor_from_elements(%[[IC0]], %[[IC1]]) : tensor<2xi64> // CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64> // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index @@ -376,7 +385,7 @@ func @tanh_dyn(%arg0: tensor) { // CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64 // CHECK: %[[DIM1:.*]] = dim %arg0, 1 : memref // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64 - // CHECK: %[[SHAPE:.*]] = 
"xla_hlo.scalars_to_dimension_tensor"(%[[IC0]], %[[IC1]]) : (i64, i64) -> tensor<2xi64> + // CHECK: %[[SHAPE:.*]] = tensor_from_elements(%[[IC0]], %[[IC1]]) : tensor<2xi64> // CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64> // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index @@ -386,3 +395,15 @@ func @tanh_dyn(%arg0: tensor) { // CHECK: "xla_lhlo.tanh"(%arg0, %[[RESULT]]) : (memref, memref) -> () return } + +// ----- + +// CHECK-LABEL: func @dot +func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> { +// CHECK-SAME: (%[[ARG0:.*]]: [[TYPE:.*]], +// CHECK-SAME: %[[RESULT:.*]]: [[TYPE]]) +// CHECK: "xla_lhlo.dot"(%[[ARG0]], %[[ARG0]], %{{.*}}) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () + %dot = "xla_hlo.dot"(%arg0, %arg0) + : (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32> + return %dot : tensor<1024x1024xf32> + } diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir index aa949a01388..a27bf2cff79 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir @@ -530,3 +530,28 @@ func @convert_f64_to_f32(%input: tensor<2x2xf64>) -> tensor<2x2xf32> { // CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f64): // CHECK-NEXT: %[[RESULT:.*]] = fptrunc %[[OPERAND_IN]] : f64 to f32 // CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + +// CHECK-LABEL: func @convert_f32_to_i32 +func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> { + %result = "xla_hlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi32> + return %result : tensor<2x2xi32> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32): +// CHECK-NEXT: %[[RESULT:.*]] = fptosi %[[OPERAND_IN]] : f32 to i32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 + +// ----- + +// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)> +// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @reverse +func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> { + %result = "xla_hlo.reverse"(%input) { + dimensions = dense<1> : tensor<1xi64> + } : (tensor<2x3xf32>) -> tensor<2x3xf32> + return %result : tensor<2x3xf32> +} +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir index cda1dc481a7..6a2b68adac3 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir @@ -8,7 +8,9 @@ // CHECK-SAME: ) { func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf32> { // The only expected instruction is a copy from the input into the output. 
- // CHECK: %[[OUTPUT:.*]] = std.view %[[ARG1]][][] : memref<16xi8> to memref<2x2xf32> + // CHECK: %[[C0:.*]] = constant 0 : index + // CHECK: %[[C02:.*]] = constant 0 : index + // CHECK: %[[OUTPUT:.*]] = std.view %[[ARG1]][%[[C02]]][] : memref<16xi8> to memref<2x2xf32> // CHECK: xla_lhlo.copy // CHECK-SAME: %[[ARG0]], %[[OUTPUT]] return %value : tensor<2x2xf32> diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-control-flow.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-control-flow.mlir index 83c3f765dc3..83880bc8ce9 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-control-flow.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-control-flow.mlir @@ -35,7 +35,7 @@ func @conditional(%arg0: tensor) -> tensor { // CHECK: [[VAL1:%.+]] = extract_element [[VAL0]][] : tensor // CHECK: cond_br [[VAL1]], ^bb1(%arg0 : tensor), ^bb2(%arg0 : tensor) - %1 = "xla_hlo.conditional"(%0, %arg0, %arg0) ( { + %1 = "xla_hlo.if"(%0, %arg0, %arg0) ( { ^bb0(%arg1: tensor): // CHECK: ^bb1([[VAL2:%.+]]: tensor): @@ -131,7 +131,7 @@ func @conditional_with_multiple_blocks(%arg0: tensor, %arg1: tensor, % // CHECK: ^[[EXIT]](%6: tensor): // CHECK: return %6 : tensor // CHECK: } - %1 = "xla_hlo.conditional"(%pred, %arg0, %arg1) ( { + %1 = "xla_hlo.if"(%pred, %arg0, %arg1) ( { ^then_entry(%arg2: tensor): br ^then_succ(%arg2: tensor) ^then_succ(%0: tensor): diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir index 08df9fd3808..3605e2a0d5c 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir @@ -7,8 +7,8 @@ func @batchmatmulv2_basic(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> { // CHECK-LABEL: func @batchmatmulv2_basic // CHECK-SAME: ([[LHS:%.*]]: tensor<1x4x2xf32>, [[RHS:%.*]]: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> -// CHECK: [[LHSSHAPE:%.*]] = "shape.shape_of"([[LHS]]) : (tensor<1x4x2xf32>) -> !shape.shape -// CHECK: [[RHSSHAPE:%.*]] = "shape.shape_of"([[RHS]]) : (tensor<3x2x4xf32>) -> !shape.shape +// CHECK: [[LHSSHAPE:%.*]] = shape.shape_of [[LHS]] : tensor<1x4x2xf32> +// CHECK: [[RHSSHAPE:%.*]] = shape.shape_of [[RHS]] : tensor<3x2x4xf32> // CHECK: [[CM2:%.*]] = constant -2 : i32 // CHECK: [[LHSHEAD:%.*]], [[LHSTAIL:%.*]] = "shape.split_at"([[LHSSHAPE]], [[CM2]]) : (!shape.shape, i32) -> (!shape.shape, !shape.shape) // CHECK: [[RHSHEAD:%.*]], [[RHSTAIL:%.*]] = "shape.split_at"([[RHSSHAPE]], [[CM2]]) : (!shape.shape, i32) -> (!shape.shape, !shape.shape) @@ -86,8 +86,8 @@ func @batchmatmulv2_adj_complex(%arg0: tensor<5x2xcomplex>, %arg1: tensor<2 // CHECK: [[RHSIM:%.*]] = "xla_hlo.imag"([[RHS]]) // CHECK: [[RHSIMNEG:%.*]] = "xla_hlo.negate"([[RHSIM]]) // CHECK: [[RHSCONJ:%.*]] = "xla_hlo.complex"([[RHSRE]], [[RHSIMNEG]]) -// CHECK: "shape.shape_of"([[LHSCONJ]]) -// CHECK: "shape.shape_of"([[RHSCONJ]]) +// CHECK: shape.shape_of [[LHSCONJ]] +// CHECK: shape.shape_of [[RHSCONJ]] %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xcomplex>, tensor<2x4xcomplex>) -> tensor<5x4xcomplex> return %0 : tensor<5x4xcomplex> } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir new file mode 100644 index 00000000000..c114b8c50a5 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir @@ -0,0 +1,334 @@ +// Note 
that binary elementwise tests are run with chlo legalization enabled +// (unlike the rest), since this is the primary use case for such ops and +// verification of shapes and broadcasts is desired. +// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=true" %s | FileCheck %s --dump-input-on-failure + +//===----------------------------------------------------------------------===// +// Binary op legalizations. +// Most of these expand from the same pattern. Full semantics are +// verified for tf.Add and pattern application only for the rest. +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @add +func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // CHECK-NEXT: %[[SUM0:.*]] = xla_hlo.add %arg0, %arg0 : tensor<2xi32> + // CHECK-NEXT: %[[SUM1:.*]] = xla_hlo.add %[[SUM0]], %arg0 : tensor<2xi32> + // CHECK-NEXT: return %[[SUM1]] : tensor<2xi32> + %0 = "tf.Add"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %1 = "tf.AddV2"(%0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + return %1: tensor<2xi32> +} + +// CHECK-LABEL: func @broadcast_add +// TODO(laurenzo): Change this to a (5 + 2x1) shaped add to make the check +// patterns unambiguous and more interesting (once broadcastable trait is +// fixed upstream). +func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { + // CHECK: %[[UNUSED_LHS_SHAPE:.+]] = shape.const_shape [1] + // CHECK: %[[UNUSED_RHS_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK: xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] + %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + return %0: tensor<1x2xi32> +} + +// CHECK-LABEL: func @broadcast_multi_dim_add +// TODO(laurenzo): Change this to a (4x1x1 + 1x4x4x4) shaped add once upstream +// broadcastable bug is fixed (helps make the CHECK matching unambiguous) +func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { + // CHECK: %[[UNUSED_LHS_SHAPE:.+]] = shape.const_shape [4, 1, 1] + // CHECK: %[[UNUSED_RHS_SHAPE:.+]] = shape.const_shape [4, 4, 4, 4] + // CHECK: %[[RESULT_SHAPE:.+]] = shape.const_shape [4, 4, 4, 4] + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} + // CHECK: xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] + %0 = "tf.Add"(%arg0, %arg1) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> + return %0: tensor<4x4x4x4xi32> +} + +// CHECK-LABEL: func @add_dynamic +func @add_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1 + // CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], 
%[[RHS_SHAPE]]) + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK: xla_hlo.add %4, %5 : tensor + %0 = "tf.Add"(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0: tensor +} + +// CHECK-LABEL: func @div +func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // CHECK-NEXT: %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32> + // CHECK-NEXT: return %0 : tensor<2xi32> + %0 = "tf.Div"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + return %0: tensor<2xi32> +} + +// CHECK-LABEL: func @shift_left +func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + // CHECK: xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32> + %0 = "tf.LeftShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %0 : tensor<4xi32> +} + +// CHECK-LABEL: func @div_unranked +func @div_unranked(%arg0: tensor<*xi32>, %arg1: tensor) -> tensor { + // CHECK: tf.Div + %0 = "tf.Div"(%arg0, %arg1) : (tensor<*xi32>, tensor) -> tensor + return %0: tensor +} + +// CHECK-LABEL: func @maximum +func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.maximum %arg0, %arg1 : tensor<4xf32> + %0 = "tf.Maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func @minimum +func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.minimum %arg0, %arg1 : tensor<4xf32> + %0 = "tf.Minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func @mul +func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // CHECK-NEXT: %0 = xla_hlo.multiply %arg0, %arg0 : tensor<2xi32> + // CHECK-NEXT: return %0 : tensor<2xi32> + %0 = "tf.Mul"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + return %0: tensor<2xi32> +} + +// CHECK-LABEL: func @real_div +func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // CHECK-NEXT: %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32> + %0 = "tf.RealDiv"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + return %0: tensor<2xi32> +} + +// CHECK-LABEL: func @sub +func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // CHECK-NEXT: %0 = xla_hlo.subtract %arg0, %arg0 : tensor<2xi32> + // CHECK-NEXT: return %0 : tensor<2xi32> + %0 = "tf.Sub"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + return %0: tensor<2xi32> +} + +// CHECK-LABEL: func @shift_right +func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + // CHECK: xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32> + %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %0 : tensor<4xi32> +} + +// CHECK-LABEL: func @shift_right_unsigned +func @shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<4xui8>) -> tensor<4xui8> { + // CHECK: tf.RightShift + %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<4xui8>) -> tensor<4xui8> + return %0 : tensor<4xui8> +} + +// CHECK-LABEL: func @broadcast_shift_right_unsigned +func @broadcast_shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<2x4xui8>) -> tensor<2x4xui8> { + // CHECK: tf.RightShift + 
%0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<2x4xui8>) -> tensor<2x4xui8> + return %0 : tensor<2x4xui8> +} + +// CHECK-LABEL: func @and +func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> { + // CHECK-NEXT: xla_hlo.and + %0 = "tf.LogicalAnd"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> + return %0: tensor<2xi1> +} + +// CHECK-LABEL: func @and_unranked +func @and_unranked(%arg0: tensor<*xi1>, %arg1: tensor<*xi1>) -> tensor<*xi1> { + // CHECK: tf.LogicalAnd + %0 = "tf.LogicalAnd"(%arg0, %arg1) : (tensor<*xi1>, tensor<*xi1>) -> tensor<*xi1> + return %0: tensor<*xi1> +} + +// CHECK-LABEL: func @or +func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> { + // CHECK-NEXT: xla_hlo.or + %0 = "tf.LogicalOr"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> + return %0: tensor<2xi1> +} + +// CHECK-LABEL: func @bitwise_or +func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + // CHECK-NEXT: xla_hlo.or + %0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %0: tensor<4xi32> +} + +// CHECK-LABEL: func @bitwise_and +func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + // CHECK-NEXT: xla_hlo.and + %0 = "tf.BitwiseAnd"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %0: tensor<4xi32> +} + +// CHECK-LABEL: func @pow +func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK-NEXT: xla_hlo.power + %0 = "tf.Pow"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + return %0: tensor<2xf32> +} + +//===----------------------------------------------------------------------===// +// Equality op legalizations. +// tf.Equal and tf.NotEqual expand from the same pattern. Full semantics are +// verified for tf.Equal and pattern application only for tf.NotEqual +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @equal +func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { + // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} + %0 = "tf.Equal"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + return %0: tensor<2xi1> +} + +// CHECK-LABEL: func @equal_dynamic +func @equal_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { + // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1] + // CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]]) + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"} + %0 = "tf.Equal"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor + return %0: tensor +} + +// CHECK-LABEL: func @equal_broadcast +func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { + // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.const_shape [1] + // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK-DAG: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, 
%[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"} + %0 = "tf.Equal"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + return %0: tensor<1x2xi1> +} + +// CHECK-LABEL: func @equal_broadcast_no_incompatible_shapes_error +func @equal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { + // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} + %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + return %0: tensor<1x2xi1> +} + +// CHECK-LABEL: func @equal_incompatible_shape_broadcastable +func @equal_incompatible_shape_broadcastable(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { + // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} + %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor, tensor<1xi32>) -> tensor + return %0: tensor +} + +// CHECK-LABEL: func @equal_incompatible_shape_dynamic +func @equal_incompatible_shape_dynamic(%arg0: tensor<2xi32>, %arg1: tensor) -> tensor<*xi1> { + // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} + %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<2xi32>, tensor) -> tensor<*xi1> + return %0: tensor<*xi1> +} + +// CHECK-LABEL: func @equal_incompatible_shape_both_dynamic +func @equal_incompatible_shape_both_dynamic(%arg0: tensor, %arg1: tensor) -> tensor<*xi1> { + // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} + %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor, tensor) -> tensor<*xi1> + return %0: tensor<*xi1> +} + +// CHECK-LABEL: func @equal_unranked +func @equal_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi1> { + // CHECK: "tf.Equal" + %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1> + return %0: tensor<*xi1> +} + +// CHECK-LABEL: func @notequal +func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> { + // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} + %0 = "tf.NotEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + return %0: tensor<2xi1> +} + +//===----------------------------------------------------------------------===// +// Compare op legalizations. +// These expand from the same pattern. Full semantics are checked for +// tf.Greater. Others just check that the pattern applied. 
+//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @greater +func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> { + // CHECK: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} + %0 = "tf.Greater"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + return %0: tensor<2xi1> +} + +// CHECK-LABEL: func @broadcast_greater +func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { + // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.const_shape [1] + // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK-DAG: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"} + %0 = "tf.Greater"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + return %0: tensor<1x2xi1> +} + +// CHECK-LABEL: func @greater_dynamic +func @greater_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1 + // CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]]) + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"} + %0 = "tf.Greater"(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0: tensor +} + +// CHECK-LABEL: func @greater_uranked +func @greater_uranked(%arg0: tensor<*xi32>) -> tensor<*xi1> { + // CHECK: "tf.Greater" + %0 = "tf.Greater"(%arg0, %arg0) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1> + return %0: tensor<*xi1> +} + +// CHECK-LABEL: func @greater_equal +func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { + // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} + %0 = "tf.GreaterEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + return %0: tensor<2xi1> +} + +// CHECK-LABEL: func @less +func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> { + // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} + %0 = "tf.Less"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + return %0: tensor<2xi1> +} + +// CHECK-LABEL: func @less_equal +func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { + // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} + %0 = "tf.LessEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + return %0: tensor<2xi1> +} diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir index 2984ba46993..b3307a8f52a 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir +++ 
b/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir @@ -1,12 +1,12 @@ // RUN: tf-opt -xla-legalize-tf-control-flow %s | FileCheck %s --dump-input-on-failure -// CHECK-LABEL: @conditional -func @conditional(%arg0: tensor, %arg1: tensor) -> (tensor) +// CHECK-LABEL: @if +func @if(%arg0: tensor, %arg1: tensor) -> (tensor) attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} { // CHECK: [[VAL0:%.+]] = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor // CHECK: [[VAL1:%.+]] = "xla_hlo.tuple"(%arg0, %arg1) - // CHECK: [[VAL2:%.+]] = "xla_hlo.conditional"([[VAL0]], [[VAL1]], [[VAL1]]) ( { + // CHECK: [[VAL2:%.+]] = "xla_hlo.if"([[VAL0]], [[VAL1]], [[VAL1]]) ( { // CHECK: ^bb0(%arg2: tuple, tensor>): // CHECK: [[VAL4:%.+]] = "xla_hlo.get_tuple_element"(%arg2) {index = 0 : i32} // CHECK: [[VAL5:%.+]] = "xla_hlo.get_tuple_element"(%arg2) {index = 1 : i32} @@ -40,7 +40,52 @@ attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} { return %0 : tensor } -// CHECK-LABEL: @while + +// CHECK-LABEL: func @case +// CHECK-SAME: %[[BRANCH_INDEX:.*]]: tensor, %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) -> (tensor, tensor) +func @case(%index: tensor, %arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + %0:2 = "tf.Case"(%index, %arg0, %arg1) {branches = [@exponential, @log, @floor]} : (tensor, tensor, tensor) -> (tensor, tensor) + // CHECK: %[[TUPLE_INPUT:.*]] = "xla_hlo.tuple"(%[[ARG0]], %[[ARG1]]) : (tensor, tensor) -> tuple, tensor> + // CHECK: %[[CASE:.*]]:2 = "xla_hlo.case"(%[[BRANCH_INDEX]], %[[TUPLE_INPUT]], %[[TUPLE_INPUT]], %[[TUPLE_INPUT]]) ( { + // CHECK: ^bb0(%[[TUPLE_ARG:.*]]: tuple, tensor>): + // CHECK: %[[TUPLE_ELEMENT_0:.*]] = "xla_hlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 0 : i32} : (tuple, tensor>) -> tensor + // CHECK: %[[TUPLE_ELEMENT_1:.*]] = "xla_hlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 1 : i32} : (tuple, tensor>) -> tensor + // CHECK: %[[CALL_EXP:.*]]:2 = call @exponential(%[[TUPLE_ELEMENT_0]], %[[TUPLE_ELEMENT_1]]) : (tensor, tensor) -> (tensor, tensor) + // CHECK: "xla_hlo.return"(%[[CALL_EXP]]#0, %[[CALL_EXP]]#1) : (tensor, tensor) -> () + // CHECK: }, { + // CHECK: ^bb0(%[[TUPLE_ARG:.*]]: tuple, tensor>): + // CHECK: %[[TUPLE_ELEMENT_0:.*]] = "xla_hlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 0 : i32} : (tuple, tensor>) -> tensor + // CHECK: %[[TUPLE_ELEMENT_1:.*]] = "xla_hlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 1 : i32} : (tuple, tensor>) -> tensor + // CHECK: %[[CALL_LOG:.*]]:2 = call @log(%[[TUPLE_ELEMENT_0]], %[[TUPLE_ELEMENT_1]]) : (tensor, tensor) -> (tensor, tensor) + // CHECK: "xla_hlo.return"(%[[CALL_LOG]]#0, %[[CALL_LOG]]#1) : (tensor, tensor) -> () + // CHECK: }, { + // CHECK: ^bb0(%[[TUPLE_ARG:.*]]: tuple, tensor>): + // CHECK: %[[TUPLE_ELEMENT_0:.*]] = "xla_hlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 0 : i32} : (tuple, tensor>) -> tensor + // CHECK: %[[TUPLE_ELEMENT_1:.*]] = "xla_hlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 1 : i32} : (tuple, tensor>) -> tensor + // CHECK: %[[CALL_FLOOR:.*]]:2 = call @floor(%[[TUPLE_ELEMENT_0]], %[[TUPLE_ELEMENT_1]]) : (tensor, tensor) -> (tensor, tensor) + // CHECK: "xla_hlo.return"(%[[CALL_FLOOR]]#0, %[[CALL_FLOOR]]#1) : (tensor, tensor) -> () + // CHECK: }) : (tensor, tuple, tensor>, tuple, tensor>, tuple, tensor>) -> (tensor, tensor) + return %0#0, %0#1 : tensor, tensor +// CHECK: return %[[CASE]]#0, %[[CASE]]#1 : tensor, tensor +} + +func 
@exponential(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + %0 = "xla_hlo.exponential"(%arg1) : (tensor) -> tensor + return %0, %arg1 : tensor, tensor +} + +func @log(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + %0 = "xla_hlo.log"(%arg0) : (tensor) -> tensor + return %0, %arg1 : tensor, tensor +} + +func @floor(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + %0 = "xla_hlo.floor"(%arg0) : (tensor) -> tensor + return %0, %arg1 : tensor, tensor +} + + +// CHECK-LABEL: func @while func @while(%arg0: tensor {tf_saved_model.index_path = [0]}) -> (tensor {tf_saved_model.index_path = []}) attributes {tf._input_shapes = ["tfshape$"]} { // CHECK: [[VAL0:%.+]] = xla_hlo.constant dense<0> diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-full-conversion.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-full-conversion.mlir index d2b4d269fef..0660af4ed1c 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-full-conversion.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-full-conversion.mlir @@ -1,22 +1,24 @@ // RUN: tf-opt %s -xla-legalize-tf -split-input-file -verify-diagnostics +// expected-error@below{{The following operations cannot be legalized: tf.NoOp (count: 1); tf_executor.fetch (count: 1); tf_executor.graph (count: 1); tf_executor.island (count: 1); tf_executor.yield (count: 1). These legalization failure(s) may be due to missing TF to HLO lowerings and/or unsupported attributes, etc.}} +// expected-error@below{{Emitting more detail about one op that failed to legalize...}} func @tf_executor_graph_op() { - // expected-error@+1 {{failed to legalize operation 'tf_executor.graph'}} tf_executor.graph { %0 = tf_executor.island { + // expected-error@+1 {{'tf.NoOp' op is not legalizable}} "tf.NoOp"() {} : () -> () tf_executor.yield } tf_executor.fetch } return - } // ----- +// expected-error@below{{The following operations cannot be legalized: tf.OpA (count: 1). These legalization failure(s) may be due to missing TF to HLO lowerings and/or unsupported attributes, etc.}} func @tf_unknown_op(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // expected-error@+1 {{failed to legalize operation 'tf.OpA'}} + // expected-error@+1 {{'tf.OpA' op is not legalizable}} %0 = "tf.OpA"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0: tensor<2xi32> } @@ -27,3 +29,16 @@ func @tf_known_op(%arg0: tensor<2xi32>) -> tensor<2xi32> { %0 = "tf.Add"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0: tensor<2xi32> } + +// ----- + +// expected-error@below{{The following operations cannot be legalized: tf.OpA (count: 1); tf.OpB (count: 2). 
These legalization failure(s) may be due to missing TF to HLO lowerings and/or unsupported attributes, etc.}} +// expected-error@below{{Emitting more detail about one op that failed to legalize...}} +func @tf_unknown_known_mix(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // expected-error@+1 {{'tf.OpA' op is not legalizable}} + %0 = "tf.OpA"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %1 = "tf.OpB"(%0, %0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %2 = "tf.Add"(%1, %1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %3 = "tf.OpB"(%2, %2) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + return %2: tensor<2xi32> +} diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir index 01398eb7314..e8d5cfe997d 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir @@ -163,6 +163,30 @@ func @truncated_normal() -> tensor<2x2xf32> { return %1 : tensor<2x2xf32> } +// CHECK-LABEL: dynamic_update_slice +// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4xi32>, %[[ARG1:.*]]: tensor<2x2xi32>, %[[ARG2:.*]]: tensor<2xi32> +func @dynamic_update_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<2x2xi32>, %arg2: tensor<2xi32>) -> tensor<3x4xi32> { + + // CHECK: %[[SLICE0:.*]] = "xla_hlo.slice"(%[[ARG2]]) + // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64> + // CHECK-DAG-SAME: limit_indices = dense<1> : tensor<1xi64> + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64> + // CHECK-SAME: (tensor<2xi32>) -> tensor<1xi32> + // CHECK: %[[DIM0:.*]] = "xla_hlo.reshape"(%[[SLICE0]]) : (tensor<1xi32>) -> tensor + + // CHECK: %[[SLICE1:.*]] = "xla_hlo.slice"(%[[ARG2]]) + // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64> + // CHECK-DAG-SAME: limit_indices = dense<2> : tensor<1xi64> + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64> + // CHECK-SAME: (tensor<2xi32>) -> tensor<1xi32> + // CHECK: %[[DIM1:.*]] = "xla_hlo.reshape"(%[[SLICE1]]) : (tensor<1xi32>) -> tensor + + // CHECK: "xla_hlo.dynamic-update-slice"(%[[ARG0]], %[[ARG1]], %[[DIM0]], %[[DIM1]]) + + %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<3x4xi32>, tensor<2x2xi32>, tensor<2xi32>) -> tensor<3x4xi32> + return %0: tensor<3x4xi32> +} + // TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is // available but doesn't support this instance. } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index e15101a165e..6406e2fee48 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -1,4 +1,11 @@ -// RUN: tf-opt -xla-legalize-tf=allow-partial-conversion %s | FileCheck %s --dump-input-on-failure +// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=false" %s | FileCheck %s --dump-input-on-failure +// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=true" -verify-diagnostics %s +// This test runs twice: +// 1. Through FileCheck with chlo legalization disabled since verifying +// that the chlo ops emit produces more useful tests. +// 2. With chlo legalization enabled, verifying diagnostics to pick up any +// issues with the full lowering (can catch some broadcasting corner +// cases which emit with a warning). //===----------------------------------------------------------------------===// // BatchNorm op legalizations. 
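// A minimal illustrative sketch (hypothetical test case, not part of this patch)
// of the pattern the RUN lines above are meant to exercise: with
// legalize-chlo=false, a broadcasting TF binary op is checked against the
// client-HLO broadcast form rather than a fully lowered xla_hlo op; the second
// RUN line additionally lowers those xla_chlo ops and only verifies diagnostics.
//
// CHECK-LABEL: func @example_broadcast_add
func @example_broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
  // CHECK: xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
  %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
  return %0 : tensor<1x2xi32>
}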
@@ -47,7 +54,7 @@ func @fusedBatchNormV3_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32> // CHECK: "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 1 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> // CHECK: %[[VAR:.*]] = "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> // CHECK: xla_hlo.constant - // CHECK: "xla_hlo.multiply"(%[[VAR]], {{.*}}) : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK: xla_chlo.broadcast_multiply %[[VAR]], {{.*}} : (tensor<8xf32>, tensor) -> tensor<8xf32> return %0#0 : tensor<8x8x8x8xf32> } @@ -68,18 +75,18 @@ func @fusedBatchNormV3_training_exponentialAvgFactor(%arg0: tensor<8x8x8x8xf32>, // CHECK-DAG: %[[BATCH_VAR:.*]] = "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} // CHECK: %[[FACTOR:.*]] = xla_hlo.constant dense<1.00195694> - // CHECK: %[[CORRECTED_VAR:.*]] = "xla_hlo.multiply"(%[[BATCH_VAR]], %[[FACTOR]]) + // CHECK: %[[CORRECTED_VAR:.*]] = xla_chlo.broadcast_multiply %[[BATCH_VAR]], %[[FACTOR]] // CHECK-DAG: %[[ALPHA:.*]] = xla_hlo.constant dense<0.199999988> // CHECK-DAG: %[[BETA:.*]] = xla_hlo.constant dense<8.000000e-01> - // CHECK: %[[ALPHA_MUL_OLD_MEAN:.*]] = "xla_hlo.multiply"(%[[ALPHA]], %arg3) - // CHECK: %[[BETA_MUL_BATCH_MEAN:.*]] = "xla_hlo.multiply"(%[[BETA]], %[[BATCH_MEAN]]) - // CHECK: %[[NEW_BATCH_MEAN:.*]] = xla_hlo.add %[[ALPHA_MUL_OLD_MEAN]], %[[BETA_MUL_BATCH_MEAN]] + // CHECK: %[[ALPHA_MUL_OLD_MEAN:.*]] = xla_chlo.broadcast_multiply %[[ALPHA]], %arg3 + // CHECK: %[[BETA_MUL_BATCH_MEAN:.*]] = xla_chlo.broadcast_multiply %[[BETA]], %[[BATCH_MEAN]] + // CHECK: %[[NEW_BATCH_MEAN:.*]] = xla_chlo.broadcast_add %[[ALPHA_MUL_OLD_MEAN]], %[[BETA_MUL_BATCH_MEAN]] - // CHECK: %[[ALPHA_MUL_OLD_VAR:.*]] = "xla_hlo.multiply"(%[[ALPHA]], %arg4) - // CHECK: %[[BETA_MUL_CORRECTED_VAR:.*]] = "xla_hlo.multiply"(%[[BETA]], %[[CORRECTED_VAR]]) - // CHECK: %[[NEW_BATCH_VAR:.*]] = xla_hlo.add %[[ALPHA_MUL_OLD_VAR]], %[[BETA_MUL_CORRECTED_VAR]] + // CHECK: %[[ALPHA_MUL_OLD_VAR:.*]] = xla_chlo.broadcast_multiply %[[ALPHA]], %arg4 + // CHECK: %[[BETA_MUL_CORRECTED_VAR:.*]] = xla_chlo.broadcast_multiply %[[BETA]], %[[CORRECTED_VAR]] + // CHECK: %[[NEW_BATCH_VAR:.*]] = xla_chlo.broadcast_add %[[ALPHA_MUL_OLD_VAR]], %[[BETA_MUL_CORRECTED_VAR]] // CHECK: return %[[NEW_BATCH_MEAN]], %[[NEW_BATCH_VAR]], %[[BATCH_MEAN]], %[[BATCH_VAR]] return %0#1, %0#2, %0#3, %0#4 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32> @@ -127,11 +134,12 @@ func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[eps:.*]] = xla_hlo.constant dense<1.000000e-03> : tensor - // CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg4, %[[eps]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.subtract"(%[[act]], %arg3) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8x8x8x8xf32> + // CHECK: %[[bcast_arg3:.+]] = 
"xla_hlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[sub:.*]] = xla_hlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor @@ -142,10 +150,10 @@ func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> - // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.multiply"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> - - // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] : tensor<8xf32> + // CHECK: %[[bcast_mul2:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul3:.*]] = xla_hlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> // CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> @@ -185,11 +193,12 @@ func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor< // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[eps:.*]] = xla_hlo.constant dense<1.000000e-03> : tensor - // CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg4, %[[eps]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.subtract"(%[[act]], %arg3) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8x8x8x8xf32> + // CHECK: %[[bcast_arg3:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[sub:.*]] = xla_hlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> 
tensor<8x8x8x8xf32> // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor @@ -200,10 +209,11 @@ func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor< // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> - // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.multiply"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] : tensor<8xf32> + // CHECK: %[[bcast_mul2:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul3:.*]] = xla_hlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> // CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> @@ -270,11 +280,12 @@ func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor< // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[eps:.*]] = xla_hlo.constant dense<1.000000e-03> : tensor - // CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg4, %[[eps]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.subtract"(%[[act]], %arg3) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8x8x8x8xf32> + // CHECK: %[[bcast_arg3:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[sub:.*]] = xla_hlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor @@ -285,10 +296,11 @@ func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor< // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] {broadcast_dimensions = 
dense<[]> : tensor<0xi64>} : tensor<8xf32> - // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.multiply"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] : tensor<8xf32> + // CHECK: %[[bcast_mul2:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul3:.*]] = xla_hlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> // CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> @@ -355,11 +367,12 @@ func @fusedBatchNormGradV3_noTraining_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: te // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[eps:.*]] = xla_hlo.constant dense<1.000000e-03> : tensor - // CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg4, %[[eps]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.subtract"(%[[act]], %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8x8x8x8xf32> + // CHECK: %[[bcast_arg3:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[sub:.*]] = xla_hlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 2, 3]> : tensor<3xi64> // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor @@ -370,10 +383,11 @@ func @fusedBatchNormGradV3_noTraining_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: te // CHECK-NEXT: }) {dimensions = dense<[0, 2, 3]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> - // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.multiply"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] : tensor<8xf32> + // CHECK: %[[bcast_mul2:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> 
tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul3:.*]] = xla_hlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 2, 3]> : tensor<3xi64> // CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> @@ -405,280 +419,41 @@ func @fusedBatchNormGradV3_Training_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tens // CHECK-LABEL: func @biasAdd_NHWC func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { - // CHECK: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} + // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK: %[[ARG0_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[ARG0_SHAPE]]) + // CHECK: %[[ARG1_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) + // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} + // CHECK: %[[RESULT:.+]] = xla_hlo.add %arg0, %[[ARG1_BCAST]] %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> return %0 : tensor<1x32x10x32xi32> } // CHECK-LABEL: func @biasAdd_NCHW func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { - // CHECK: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK: %[[ARG0_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[ARG0_SHAPE]]) + // CHECK: %[[ARG1_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) + // CHECK-SAME: {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK: %[[RESULT:.+]] = xla_hlo.add %arg0, %[[ARG1_BCAST]] %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> return %0 : tensor<1x32x10x32xi32> } // CHECK-LABEL: func @biasAdd_dynamic func @biasAdd_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK: %[[ARG0_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[ARG0_SHAPE]]) + // CHECK: %[[ARG1_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) + // CHECK-SAME: {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK: %[[RESULT:.+]] = xla_hlo.add %arg0, %[[ARG1_BCAST]] %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW"} : (tensor, tensor) -> tensor return %0 : tensor } //===----------------------------------------------------------------------===// -// Binary op legalizations. 
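// For reference, a hedged sketch (illustrative only, not emitted verbatim by the
// pass) of the expanded IR that the dynamic-broadcast CHECK lines for BiasAdd
// above abbreviate; the !shape.shape operand type and the value names here are
// assumptions, not taken from the pass output:
//
//   %shape   = shape.shape_of %arg0 : tensor<1x32x10x32xi32>
//   %extents = "shape.to_extent_tensor"(%shape) : (!shape.shape) -> tensor<4xindex>
//   %bcast   = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %extents)
//                {broadcast_dimensions = dense<3> : tensor<1xi64>}
//                : (tensor<32xi32>, tensor<4xindex>) -> tensor<1x32x10x32xi32>
//   %sum     = xla_hlo.add %arg0, %bcast : tensor<1x32x10x32xi32>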
+// DiagPart //===----------------------------------------------------------------------===// -// CHECK-LABEL: func @add -func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %[[SUM0:.*]] = xla_hlo.add %arg0, %arg0 : tensor<2xi32> - // CHECK-NEXT: %[[SUM1:.*]] = xla_hlo.add %[[SUM0]], %arg0 : tensor<2xi32> - // CHECK-NEXT: return %[[SUM1]] : tensor<2xi32> - %0 = "tf.Add"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - %1 = "tf.AddV2"(%0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - return %1: tensor<2xi32> -} - -// CHECK-LABEL: func @broadcast_add -func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - // CHECK-NEXT: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} - %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> - return %0: tensor<1x2xi32> -} - -// CHECK-LABEL: func @broadcast_multi_dim_add -func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { - // CHECK-NEXT: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} - %0 = "tf.Add"(%arg0, %arg1) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> - return %0: tensor<4x4x4x4xi32> -} - -// CHECK-LABEL: func @div -func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32> - // CHECK-NEXT: return %0 : tensor<2xi32> - %0 = "tf.Div"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - return %0: tensor<2xi32> -} - -// CHECK-LABEL: func @broadcast_div -func @broadcast_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - // CHECK-NEXT: "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} - %0 = "tf.Div"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> - return %0: tensor<1x2xi32> -} - -// CHECK-LABEL: func @shift_left -func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - // CHECK: xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32> - %0 = "tf.LeftShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - return %0 : tensor<4xi32> -} - -// CHECK-LABEL: func @div_dynamic -func @div_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} - %0 = "tf.Div"(%arg0, %arg1) : (tensor, tensor) -> tensor - return %0: tensor -} - -// CHECK-LABEL: func @div_unranked -func @div_unranked(%arg0: tensor<*xi32>, %arg1: tensor) -> tensor { - // CHECK: tf.Div - %0 = "tf.Div"(%arg0, %arg1) : (tensor<*xi32>, tensor) -> tensor - return %0: tensor -} - -// CHECK-LABEL: func @maximum -func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.maximum %arg0, %arg1 : tensor<4xf32> - %0 = "tf.Maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %0 : tensor<4xf32> -} - -// CHECK-LABEL: func @minimum -func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.minimum %arg0, %arg1 : tensor<4xf32> - %0 = "tf.Minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %0 : tensor<4xf32> -} - -// CHECK-LABEL: func @mul -func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %0 = xla_hlo.multiply %arg0, %arg0 : tensor<2xi32> - // CHECK-NEXT: return %0 : tensor<2xi32> - %0 = "tf.Mul"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> 
- return %0: tensor<2xi32> -} - -// CHECK-LABEL: func @broadcast_mul -func @broadcast_mul(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - // CHECK-NEXT: "xla_hlo.multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} - %0 = "tf.Mul"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> - return %0: tensor<1x2xi32> -} - -// CHECK-LABEL: func @real_div -func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32> - %0 = "tf.RealDiv"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - return %0: tensor<2xi32> -} - -// CHECK-LABEL: func @broadcast_real_div -func @broadcast_real_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - // CHECK-NEXT: "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} - %0 = "tf.RealDiv"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> - return %0: tensor<1x2xi32> -} - -// CHECK-LABEL: func @sub -func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %0 = xla_hlo.subtract %arg0, %arg0 : tensor<2xi32> - // CHECK-NEXT: return %0 : tensor<2xi32> - %0 = "tf.Sub"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - return %0: tensor<2xi32> -} - -// CHECK-LABEL: func @broadcast_sub -func @broadcast_sub(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - // CHECK-NEXT: "xla_hlo.subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} - %0 = "tf.Sub"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> - return %0: tensor<1x2xi32> -} - -// CHECK-LABEL: func @shift_right -func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - // CHECK: xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32> - %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - return %0 : tensor<4xi32> -} - -// CHECK-LABEL: func @broadcast_shift_right -func @broadcast_shift_right(%arg0: tensor<4xi32>, %arg1: tensor<2x4xi32>) -> tensor<2x4xi32> { - // CHECK: "xla_hlo.shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} - %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> - return %0 : tensor<2x4xi32> -} - -// CHECK-LABEL: func @shift_right_unsigned -func @shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<4xui8>) -> tensor<4xui8> { - // CHECK: tf.RightShift - %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<4xui8>) -> tensor<4xui8> - return %0 : tensor<4xui8> -} - -// CHECK-LABEL: func @broadcast_shift_right_unsigned -func @broadcast_shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<2x4xui8>) -> tensor<2x4xui8> { - // CHECK: tf.RightShift - %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<2x4xui8>) -> tensor<2x4xui8> - return %0 : tensor<2x4xui8> -} - -// CHECK-LABEL: func @and -func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> { - // CHECK-NEXT: xla_hlo.and - %0 = "tf.LogicalAnd"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - -// CHECK-LABEL: func @and_broadcast -func @and_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> { - // CHECK-NEXT: "xla_hlo.and" - %0 = "tf.LogicalAnd"(%arg0, %arg1) : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - -// CHECK-LABEL: func @and_dynamic -func @and_dynamic(%arg0: tensor, %arg1: tensor<1xi1>) -> tensor { - // CHECK-NEXT: "xla_hlo.and" - 
%0 = "tf.LogicalAnd"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor - return %0: tensor -} - -// CHECK-LABEL: func @and_unranked -func @and_unranked(%arg0: tensor<*xi1>, %arg1: tensor<*xi1>) -> tensor<*xi1> { - // CHECK: tf.LogicalAnd - %0 = "tf.LogicalAnd"(%arg0, %arg1) : (tensor<*xi1>, tensor<*xi1>) -> tensor<*xi1> - return %0: tensor<*xi1> -} - -// CHECK-LABEL: func @or -func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> { - // CHECK-NEXT: xla_hlo.or - %0 = "tf.LogicalOr"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - -// CHECK-LABEL: func @or_broadcast -func @or_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> { - // CHECK-NEXT: xla_hlo.or - %0 = "tf.LogicalOr"(%arg0, %arg1) : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - -// CHECK-LABEL: func @or_dynamic -func @or_dynamic(%arg0: tensor, %arg1: tensor<1xi1>) -> tensor { - // CHECK-NEXT: xla_hlo.or - %0 = "tf.LogicalOr"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor - return %0: tensor -} - -// CHECK-LABEL: func @bitwise_or -func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - // CHECK-NEXT: xla_hlo.or - %0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - return %0: tensor<4xi32> -} - -// CHECK-LABEL: func @bitwise_or_broadcast -func @bitwise_or_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { - // CHECK-NEXT: xla_hlo.or - %0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> - return %0: tensor<1x4xi8> -} - -// CHECK-LABEL: func @bitwise_or_dynamic -func @bitwise_or_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - // CHECK-NEXT: xla_hlo.or - %0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor - return %0: tensor -} - -// CHECK-LABEL: func @bitwise_and -func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - // CHECK-NEXT: xla_hlo.and - %0 = "tf.BitwiseAnd"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - return %0: tensor<4xi32> -} - -// CHECK-LABEL: func @bitwise_and_broadcast -func @bitwise_and_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { - // CHECK-NEXT: xla_hlo.and - %0 = "tf.BitwiseAnd"(%arg0, %arg1) : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> - return %0: tensor<1x4xi8> -} - -// CHECK-LABEL: func @bitwise_and_dynamic -func @bitwise_and_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - // CHECK-NEXT: xla_hlo.and - %0 = "tf.BitwiseAnd"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor - return %0: tensor -} - -// CHECK-LABEL: func @pow -func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK-NEXT: xla_hlo.power - %0 = "tf.Pow"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> - return %0: tensor<2xf32> -} - -// CHECK-LABEL: func @pow_dynamic -func @pow_dynamic(%arg0: tensor) -> tensor { - // CHECK-NEXT: xla_hlo.power - %0 = "tf.Pow"(%arg0, %arg0) : (tensor, tensor) -> tensor - return %0: tensor -} - // CHECK-LABEL: func @diag_part // CHECK-SAME: %[[ARG:.*]]: tensor<4x3x4x3xf32> func @diag_part(%arg0: tensor<4x3x4x3xf32>) -> tensor<4x3xf32> { @@ -698,6 +473,10 @@ func @diag_part(%arg0: tensor<4x3x4x3xf32>) -> tensor<4x3xf32> { return %0: tensor<4x3xf32> } +//===----------------------------------------------------------------------===// +// Einsum. 
+//===----------------------------------------------------------------------===// + // CHECK-LABEL: func @einsum func @einsum(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x4xf32> { // CHECK: xla_hlo.einsum @@ -712,22 +491,26 @@ func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> { return %0: tensor<2x2xf32> } +//===----------------------------------------------------------------------===// +// FloorDiv and FloorMod. +//===----------------------------------------------------------------------===// + // CHECK-LABEL: func @floordiv_broadcast_i32 func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { // CHECK-DAG: [[ZEROS1:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP1:%.+]] = "xla_hlo.compare"(%arg0, [[ZEROS1]]) {comparison_direction = "LT"} + // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"} // CHECK-DAG: [[ZEROS2:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP2:%.+]] = "xla_hlo.compare"(%arg1, [[ZEROS2]]) {comparison_direction = "LT"} - // CHECK-DAG: [[CMP3:%.+]] = "xla_hlo.compare"([[CMP1]], [[CMP2]]) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} - // CHECK-DAG: [[DIV1:%.+]] = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} + // CHECK-DAG: [[DIV1:%.+]] = xla_chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[ABS1:%.+]] = "xla_hlo.abs"(%arg0) // CHECK-DAG: [[ABS2:%.+]] = "xla_hlo.abs"(%arg1) - // CHECK-DAG: [[ZEROS3:%.+]] = xla_hlo.constant dense<1> - // CHECK-DAG: [[SUB:%.+]] = xla_hlo.subtract [[ABS2]], [[ZEROS3]] - // CHECK-DAG: [[ADD:%.+]] = "xla_hlo.add"([[ABS1]], [[SUB]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[ONES:%.+]] = xla_hlo.constant dense<1> + // CHECK-DAG: [[SUB:%.+]] = xla_chlo.broadcast_subtract [[ABS2]], [[ONES]] + // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[NEG:%.+]] = "xla_hlo.negate"([[ADD]]) // CHECK-DAG: [[ABS3:%.+]] = "xla_hlo.abs"(%arg1) - // CHECK-DAG: [[DIV2:%.+]] = "xla_hlo.divide"([[NEG]], [[ABS3]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[DIV2:%.+]] = xla_chlo.broadcast_divide [[NEG]], [[ABS3]] {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[SELECT:%.+]] = "xla_hlo.select"([[CMP3]], [[DIV1]], [[DIV2]]) // CHECK: return [[SELECT]] %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> @@ -737,19 +520,19 @@ func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> te // CHECK-LABEL: func @floordiv_reverse_broadcast_i32 func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { // CHECK-DAG: [[ZEROS1:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP1:%.+]] = "xla_hlo.compare"(%arg0, [[ZEROS1]]) {comparison_direction = "LT"} + // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"} // CHECK-DAG: [[ZEROS2:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP2:%.+]] = "xla_hlo.compare"(%arg1, [[ZEROS2]]) {comparison_direction = "LT"} - // CHECK-DAG: 
[[CMP3:%.+]] = "xla_hlo.compare"([[CMP1]], [[CMP2]]) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} - // CHECK-DAG: [[DIV1:%.+]] = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} + // CHECK-DAG: [[DIV1:%.+]] = xla_chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[ABS1:%.+]] = "xla_hlo.abs"(%arg0) // CHECK-DAG: [[ABS2:%.+]] = "xla_hlo.abs"(%arg1) - // CHECK-DAG: [[ZEROS3:%.+]] = xla_hlo.constant dense<1> - // CHECK-DAG: [[SUB:%.+]] = xla_hlo.subtract [[ABS2]], [[ZEROS3]] - // CHECK-DAG: [[ADD:%.+]] = "xla_hlo.add"([[ABS1]], [[SUB]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[ONES:%.+]] = xla_hlo.constant dense<1> + // CHECK-DAG: [[SUB:%.+]] = xla_chlo.broadcast_subtract [[ABS2]], [[ONES]] + // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[NEG:%.+]] = "xla_hlo.negate"([[ADD]]) // CHECK-DAG: [[ABS3:%.+]] = "xla_hlo.abs"(%arg1) - // CHECK-DAG: [[DIV2:%.+]] = xla_hlo.divide [[NEG]], [[ABS3]] + // CHECK-DAG: [[DIV2:%.+]] = xla_chlo.broadcast_divide [[NEG]], [[ABS3]] // CHECK-DAG: [[SELECT:%.+]] = "xla_hlo.select"([[CMP3]], [[DIV1]], [[DIV2]]) // CHECK: return [[SELECT]] %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> @@ -758,7 +541,7 @@ func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32 // CHECK-LABEL: func @floordiv_f32 func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK-NEXT: %[[DIV:.*]] = xla_hlo.divide %arg0, %arg0 + // CHECK-NEXT: %[[DIV:.*]] = xla_chlo.broadcast_divide %arg0, %arg0 // CHECK-NEXT: %[[FLOOR:.*]] = "xla_hlo.floor"(%[[DIV]]) // CHECK-NEXT: return %[[FLOOR]] : tensor<2xf32> %0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> @@ -769,7 +552,7 @@ func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { func @floordiv_bf16(%arg0: tensor<2xbf16>) -> tensor<2xbf16> { // CHECK-NEXT: xla_hlo.convert // CHECK-NEXT: xla_hlo.convert - // CHECK-NEXT: xla_hlo.divide + // CHECK-NEXT: xla_chlo.broadcast_divide // CHECK-NEXT: xla_hlo.floor // CHECK-NEXT: xla_hlo.convert // CHECK-NEXT: return @@ -779,7 +562,7 @@ func @floordiv_bf16(%arg0: tensor<2xbf16>) -> tensor<2xbf16> { // CHECK-LABEL: func @floordiv_f16_broadcast func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> { - // CHECK-NEXT: xla_hlo.divide + // CHECK-NEXT: xla_chlo.broadcast_divide // CHECK-NEXT: xla_hlo.floor // CHECK-NEXT: return %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> @@ -788,7 +571,22 @@ func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> te // CHECK-LABEL: func @floordiv_dynamic func @floordiv_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: tf.FloorDiv + // CHECK-DAG: [[ZEROS1:%.+]] = xla_hlo.constant dense<0> + // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"} + // CHECK-DAG: [[ZEROS2:%.+]] = xla_hlo.constant dense<0> + // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"} + // CHECK-DAG: 
[[CMP3:%.+]] = xla_chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} + // CHECK-DAG: [[DIV1:%.+]] = xla_chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[ABS1:%.+]] = "xla_hlo.abs"(%arg0) + // CHECK-DAG: [[ABS2:%.+]] = "xla_hlo.abs"(%arg1) + // CHECK-DAG: [[ONES:%.+]] = xla_hlo.constant dense<1> + // CHECK-DAG: [[SUB:%.+]] = xla_chlo.broadcast_subtract [[ABS2]], [[ONES]] + // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[NEG:%.+]] = "xla_hlo.negate"([[ADD]]) + // CHECK-DAG: [[ABS3:%.+]] = "xla_hlo.abs"(%arg1) + // CHECK-DAG: [[DIV2:%.+]] = xla_chlo.broadcast_divide [[NEG]], [[ABS3]] {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[SELECT:%.+]] = "xla_hlo.select"([[CMP3]], [[DIV1]], [[DIV2]]) + // CHECK: return [[SELECT]] %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor, tensor) -> tensor return %0: tensor } @@ -802,15 +600,15 @@ func @floordiv_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*x // CHECK-LABEL: func @floormod_broadcast_numerator func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { - // CHECK-DAG: [[REM:%.+]] = "xla_hlo.remainder"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[REM:%.+]] = xla_chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[ZL:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP1:%.+]] = "xla_hlo.compare"([[REM]], [[ZL]]) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} + // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZL]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} // CHECK-DAG: [[ZR:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP2:%.+]] = "xla_hlo.compare"(%arg1, [[ZR:%.+]]) {comparison_direction = "LT"} - // CHECK-DAG: [[CMP3:%.+]] = "xla_hlo.compare"([[REM:%.+]], [[ZR]]) {comparison_direction = "LT"} - // CHECK-DAG: [[CMP4:%.+]] = "xla_hlo.compare"([[CMP2]], [[CMP3]]) {comparison_direction = "NE"} - // CHECK-DAG: [[AND:%.+]] = xla_hlo.and [[CMP1]], [[CMP4]] - // CHECK-DAG: [[ADD:%.+]] = xla_hlo.add %arg1, [[REM]] + // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"} + // CHECK-DAG: [[CMP4:%.+]] = xla_chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = "NE"} + // CHECK-DAG: [[AND:%.+]] = xla_chlo.broadcast_and [[CMP1]], [[CMP4]] + // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add %arg1, [[REM]] // CHECK-DAG: [[SELECT:%.+]] = "xla_hlo.select"([[AND]], [[ADD]], [[REM]]) // CHECK-NEXT: return [[SELECT]] %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> @@ -819,15 +617,15 @@ func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) // CHECK-LABEL: func @floormod_broadcast_denominator func @floormod_broadcast_denominator(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { - // CHECK-DAG: [[REM:%.+]] = "xla_hlo.remainder"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[REM:%.+]] = xla_chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} // 
CHECK-DAG: [[ZL:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP1:%.+]] = "xla_hlo.compare"([[REM]], [[ZL]]) {comparison_direction = "NE"} + // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = "NE"} // CHECK-DAG: [[ZR:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP2:%.+]] = "xla_hlo.compare"(%arg1, [[ZR:%.+]]) {comparison_direction = "LT"} - // CHECK-DAG: [[CMP3:%.+]] = "xla_hlo.compare"([[REM:%.+]], [[ZR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} - // CHECK-DAG: [[CMP4:%.+]] = "xla_hlo.compare"([[CMP2]], [[CMP3]]) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} - // CHECK-DAG: [[AND:%.+]] = xla_hlo.and [[CMP1]], [[CMP4]] - // CHECK-DAG: [[ADD:%.+]] = "xla_hlo.add"(%arg1, [[REM]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"} + // CHECK-DAG: [[CMP4:%.+]] = xla_chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} + // CHECK-DAG: [[AND:%.+]] = xla_chlo.broadcast_and [[CMP1]], [[CMP4]] + // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[SELECT:%.+]] = "xla_hlo.select"([[AND]], [[ADD]], [[REM]]) // CHECK-NEXT: return [[SELECT]] %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> @@ -836,7 +634,17 @@ func @floormod_broadcast_denominator(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32 // CHECK-LABEL: func @floormod_dynamic func @floormod_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: tf.FloorMod + // CHECK-DAG: [[REM:%.+]] = xla_chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[ZL:%.+]] = xla_hlo.constant dense<0> + // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = "NE"} + // CHECK-DAG: [[ZR:%.+]] = xla_hlo.constant dense<0> + // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"} + // CHECK-DAG: [[CMP4:%.+]] = xla_chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} + // CHECK-DAG: [[AND:%.+]] = xla_chlo.broadcast_and [[CMP1]], [[CMP4]] + // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[SELECT:%.+]] = "xla_hlo.select"([[AND]], [[ADD]], [[REM]]) + // CHECK-NEXT: return [[SELECT]] %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor, tensor) -> tensor return %0: tensor } @@ -848,6 +656,10 @@ func @floormod_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*x return %0: tensor<*xi32> } +//===----------------------------------------------------------------------===// +// BroadcastTo. 
+//===----------------------------------------------------------------------===// + // CHECK-LABEL: func @broadcast_to func @broadcast_to(%arg0: tensor<16xf32>) -> tensor<16x16x16x16xf32> { %cst = "tf.Const"() { value = dense<16> : tensor<4xi32> } : () -> tensor<4xi32> @@ -860,190 +672,6 @@ func @broadcast_to(%arg0: tensor<16xf32>) -> tensor<16x16x16x16xf32> { return %0 : tensor<16x16x16x16xf32> } -//===----------------------------------------------------------------------===// -// Equality op legalizations. -//===----------------------------------------------------------------------===// - -// CHECK-LABEL: func @equal -func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} - %0 = "tf.Equal"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - -// CHECK-LABEL: func @equal_dynamic -func @equal_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} - %0 = "tf.Equal"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor - return %0: tensor -} - -// CHECK-LABEL: func @equal_broadcast -func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} - %0 = "tf.Equal"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - -// CHECK-LABEL: func @equal_broadcast_no_incompatible_shapes_error -func @equal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} - %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - -// CHECK-LABEL: func @equal_incompatible_shape_broadcastable -func @equal_incompatible_shape_broadcastable(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} - %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor, tensor<1xi32>) -> tensor - return %0: tensor -} - -// CHECK-LABEL: func @equal_incompatible_shape_dynamic -func @equal_incompatible_shape_dynamic(%arg0: tensor<2xi32>, %arg1: tensor) -> tensor<*xi1> { - // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} - %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<2xi32>, tensor) -> tensor<*xi1> - return %0: tensor<*xi1> -} - -// CHECK-LABEL: func @equal_incompatible_shape_both_dynamic -func @equal_incompatible_shape_both_dynamic(%arg0: tensor, %arg1: tensor) -> tensor<*xi1> { - // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} - %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor, tensor) -> tensor<*xi1> - return %0: tensor<*xi1> -} - -// CHECK-LABEL: func @equal_unranked -func @equal_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi1> { - // CHECK: "tf.Equal" - %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1> - return %0: tensor<*xi1> -} - -// CHECK-LABEL: func @notequal -func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} - %0 = "tf.NotEqual"(%arg0, %arg0) : 
(tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - -// CHECK-LABEL: func @notequal_dynamic -func @notequal_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} - %0 = "tf.NotEqual"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor - return %0: tensor -} - -// CHECK-LABEL: func @notequal_broadcast -func @notequal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} - %0 = "tf.NotEqual"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - -// CHECK-LABEL: func @notequal_broadcast_no_incompatible_shapes_error -func @notequal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-NEXT: "tf.NotEqual"(%arg0, %arg1) {incompatible_shape_error = false} - %0 = "tf.NotEqual"(%arg0, %arg1) {incompatible_shape_error = false} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - -// CHECK-LABEL: func @notequal_incompatible_shape_broadcastable -func @notequal_incompatible_shape_broadcastable(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - // CHECK-NEXT: "tf.NotEqual"(%arg0, %arg1) {incompatible_shape_error = false} - %0 = "tf.NotEqual"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor, tensor<1xi32>) -> tensor - return %0: tensor -} - -// CHECK-LABEL: func @notequal_incompatible_shape_dynamic -func @notequal_incompatible_shape_dynamic(%arg0: tensor<2xi32>, %arg1: tensor) -> tensor<*xi1> { - // CHECK-NEXT: "tf.NotEqual"(%arg0, %arg1) {incompatible_shape_error = false} - %0 = "tf.NotEqual"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<2xi32>, tensor) -> tensor<*xi1> - return %0: tensor<*xi1> -} - -// CHECK-LABEL: func @notequal_incompatible_shape_both_dynamic -func @notequal_incompatible_shape_both_dynamic(%arg0: tensor, %arg1: tensor) -> tensor<*xi1> { - // CHECK-NEXT: "tf.NotEqual"(%arg0, %arg1) {incompatible_shape_error = false} - %0 = "tf.NotEqual"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor, tensor) -> tensor<*xi1> - return %0: tensor<*xi1> -} - -//===----------------------------------------------------------------------===// -// Compare op legalizations. 
-//===----------------------------------------------------------------------===// - -// CHECK-LABEL: func @greater -func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} - %0 = "tf.Greater"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - -// CHECK-LABEL: func @broadcast_greater -func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GT"} - %0 = "tf.Greater"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - -// CHECK-LABEL: func @greater_dynamic -func @greater_dynamic(%arg0: tensor) -> tensor { - // CHECK: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} - %0 = "tf.Greater"(%arg0, %arg0) : (tensor, tensor) -> tensor - return %0: tensor -} - -// CHECK-LABEL: func @greater_uranked -func @greater_uranked(%arg0: tensor<*xi32>) -> tensor<*xi1> { - // CHECK: "tf.Greater" - %0 = "tf.Greater"(%arg0, %arg0) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1> - return %0: tensor<*xi1> -} - -// CHECK-LABEL: func @greater_equal -func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} - %0 = "tf.GreaterEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - -// CHECK-LABEL: func @broadcast_greater_equal -func @broadcast_greater_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GE"} - %0 = "tf.GreaterEqual"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - -// CHECK-LABEL: func @less -func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} - %0 = "tf.Less"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - -// CHECK-LABEL: func @broadcast_less -func @broadcast_less(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} - %0 = "tf.Less"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - -// CHECK-LABEL: func @less_equal -func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} - %0 = "tf.LessEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - -// CHECK-LABEL: func @broadcast_less_equal -func @broadcast_less_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LE"} - %0 = "tf.LessEqual"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - - //===----------------------------------------------------------------------===// // Complex op legalizations. 
//===----------------------------------------------------------------------===// @@ -1332,12 +960,12 @@ func @matrix_band_part(%arg0: tensor<64x64xbf16>, %arg1: tensor, %arg2: ten // CHECK: %[[X:.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<64x64xbf16> // CHECK: %[[Y:.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<64x64xbf16> // CHECK: %[[OFFSET:.*]] = xla_hlo.subtract %[[X]], %[[Y]] : tensor<64x64xbf16> - // CHECK: %[[G:.*]] = "xla_hlo.compare"(%[[F]], %[[OFFSET]]) {comparison_direction = "LE"} : (tensor, tensor<64x64xbf16>) -> tensor<*xi1> + // CHECK: %[[G:.*]] = xla_chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor, tensor<64x64xbf16>) -> tensor<64x64xi1> // CHECK: %[[H:.*]] = "xla_hlo.convert"(%[[D]]) : (tensor) -> tensor - // CHECK: %[[I:.*]] = "xla_hlo.compare"(%[[OFFSET]], %[[H]]) {comparison_direction = "LE"} : (tensor<64x64xbf16>, tensor) -> tensor<*xi1> + // CHECK: %[[I:.*]] = xla_chlo.broadcast_compare %[[OFFSET]], %[[H]] {comparison_direction = "LE"} : (tensor<64x64xbf16>, tensor) -> tensor<64x64xi1> - // CHECK: %[[J:.*]] = xla_hlo.and %[[G]], %[[I]] : tensor<*xi1> + // CHECK: %[[J:.*]] = xla_hlo.and %[[G]], %[[I]] : tensor<64x64xi1> // CHECK: %[[ZERO2:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<64x64xbf16> // CHECK: %[[R:.*]] = "xla_hlo.select"(%[[J]], %[[INPUT]], %[[ZERO2]]) @@ -1353,11 +981,11 @@ func @matrix_band_part_2(%arg0: tensor<12x24x48xbf16>, %arg1: tensor, %arg2 // CHECK: %[[Y:.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<24x48xbf16> // CHECK: %[[OFFSET:.*]] = xla_hlo.subtract %[[X]], %[[Y]] : tensor<24x48xbf16> - // CHECK: %[[G:.*]] = "xla_hlo.compare"(%[[F]], %[[OFFSET]]) {comparison_direction = "LE"} : (tensor, tensor<24x48xbf16>) -> tensor<*xi1> + // CHECK: %[[G:.*]] = xla_chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor, tensor<24x48xbf16>) -> tensor<24x48xi1> // CHECK: %[[H:.*]] = "xla_hlo.convert"(%[[D]]) : (tensor) -> tensor - // CHECK: %[[I:.*]] = "xla_hlo.compare"(%[[OFFSET]], %[[H]]) {comparison_direction = "LE"} : (tensor<24x48xbf16>, tensor) -> tensor<*xi1> - // CHECK: %[[J:.*]] = xla_hlo.and %[[G]], %[[I]] {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : tensor<*xi1> + // CHECK: %[[I:.*]] = xla_chlo.broadcast_compare %[[OFFSET]], %[[H]] {comparison_direction = "LE"} : (tensor<24x48xbf16>, tensor) -> tensor<24x48xi1> + // CHECK: %[[J:.*]] = xla_hlo.and %[[G]], %[[I]] : tensor<24x48xi1> // CHECK: %[[ZERO2:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<12x24x48xbf16> // CHECK: %[[R:.*]] = "xla_hlo.select"(%[[J]], %[[INPUT]], %[[ZERO2]]) @@ -1504,7 +1132,8 @@ func @max_pool_3d_grad_same(%orig_input: tensor<2x8x13x25x7xf32>, %orig_output: // CHECK-LABEL:one_hot func @one_hot(%indices: tensor<3xi32>, %on_value: tensor, %off_value: tensor) -> tensor<3x5xf32> { // CHECK: %[[IOTA:.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<3x5xi32> - // CHECK: %[[COMPARE:.*]] = "xla_hlo.compare"(%arg0, %[[IOTA]]) {broadcast_dimensions = dense<0> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi32>, tensor<3x5xi32>) -> tensor<3x5xi1> + // CHECK: %[[BCAST_ARG0:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<3x5xi32> + // CHECK: %[[COMPARE:.*]] = "xla_hlo.compare"(%[[BCAST_ARG0]], %[[IOTA]]) {comparison_direction = "EQ"} : (tensor<3x5xi32>, tensor<3x5xi32>) -> tensor<3x5xi1> // CHECK: %[[ON_VALUE:.*]] = "xla_hlo.broadcast"(%arg1) 
{broadcast_sizes = dense<[3, 5]> : tensor<2xi64>} : (tensor) -> tensor<3x5xf32> // CHECK: %[[OFF_VALUE:.*]] = "xla_hlo.broadcast"(%arg2) {broadcast_sizes = dense<[3, 5]> : tensor<2xi64>} : (tensor) -> tensor<3x5xf32> // CHECK: %[[RESULT:.*]] = "xla_hlo.select"(%[[COMPARE]], %[[ON_VALUE]], %[[OFF_VALUE]]) : (tensor<3x5xi1>, tensor<3x5xf32>, tensor<3x5xf32>) -> tensor<3x5xf32> @@ -1596,6 +1225,44 @@ func @unhandled_partitioned_call_2(%arg0: tensor, %arg1: tensor<*xi32>) -> return %0, %1 : tensor, tensor } + +//===----------------------------------------------------------------------===// +// ReverseV2 op legalization. +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @reverse_func_32 +func @reverse_func_32(%arg0: tensor<5xi32>) -> tensor<5xi32> { + %axis = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> (tensor<1xi32>) + + // CHECK: [[VAL:%.+]] = "xla_hlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} + %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi32>) -> tensor<5xi32> + + // CHECK: return [[VAL]] : tensor<5xi32> + return %reversed : tensor<5xi32> +} + +// CHECK-LABEL: @reverse_func_64 +func @reverse_func_64(%arg0: tensor<5xi32>) -> tensor<5xi32> { + %axis = "tf.Const"() {value = dense<0> : tensor<1xi64>} : () -> (tensor<1xi64>) + + // CHECK: [[VAL:%.+]] = "xla_hlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} + %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi64>) -> tensor<5xi32> + + // CHECK: return [[VAL]] : tensor<5xi32> + return %reversed : tensor<5xi32> +} + +// CHECK-LABEL: @reverse_func_neg +func @reverse_func_neg(%arg0: tensor<5x5xi32>) -> tensor<5x5xi32> { + %axis = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>) + + // CHECK: [[VAL:%.+]] = "xla_hlo.reverse"(%arg0) {dimensions = dense<1> : tensor<1xi64>} + %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5x5xi32>, tensor<1xi32>) -> tensor<5x5xi32> + + // CHECK: return [[VAL]] : tensor<5x5xi32> + return %reversed : tensor<5x5xi32> +} + //===----------------------------------------------------------------------===// // StatefulPartitionedCall op legalization. 
//===----------------------------------------------------------------------===// @@ -1631,7 +1298,7 @@ func @stateful_pcall_multi_in_out(%arg0: tensor, %arg1: tensor) -> (te // CHECK-LABEL: func @relu func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0> : tensor - // CHECK: "xla_hlo.maximum"(%[[ZERO]], %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> tensor<1xi32> + // CHECK: xla_chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> tensor<1xi32> %0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> return %0: tensor<1xi32> } @@ -1639,7 +1306,7 @@ func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK-LABEL: func @relu_unranked func @relu_unranked(%arg0: tensor) -> tensor { // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0> : tensor - // CHECK: "xla_hlo.maximum"(%[[ZERO]], %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor + // CHECK: xla_chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor %0 = "tf.Relu"(%arg0) : (tensor) -> tensor return %0: tensor } @@ -1667,8 +1334,8 @@ func @relu6_unranked(%arg0: tensor) -> tensor { func @relu_grad(%gradients: tensor<4x8xf32>, %features: tensor) -> tensor<4x8xf32> { // CHECK-DAG: %[[ZERO_SCALAR:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<4x8xf32> - // CHECK-DAG: %[[PRED:.*]] = "xla_hlo.compare"(%[[FEATURES]], %[[ZERO_SCALAR]]) {comparison_direction = "GT"} : (tensor, tensor) -> tensor<*xi1> - // CHECK-DAG: %[[RESULT:.*]] = "xla_hlo.select"(%[[PRED]], %[[GRADIENTS]], %[[ZERO]]) : (tensor<*xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> + // CHECK-DAG: %[[PRED:.*]] = xla_chlo.broadcast_compare %[[FEATURES]], %[[ZERO_SCALAR]] {comparison_direction = "GT"} : (tensor, tensor) -> tensor + // CHECK-DAG: %[[RESULT:.*]] = "xla_hlo.select"(%[[PRED]], %[[GRADIENTS]], %[[ZERO]]) : (tensor, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> // CHECK-DAG: return %[[RESULT]] : tensor<4x8xf32> %2 = "tf.ReluGrad"(%gradients, %features) : (tensor<4x8xf32>, tensor) -> tensor<4x8xf32> return %2 : tensor<4x8xf32> @@ -1678,27 +1345,6 @@ func @relu_grad(%gradients: tensor<4x8xf32>, %features: tensor) -> tens // Select op legalizations. 
//===----------------------------------------------------------------------===// -// CHECK-LABEL: func @select -func @select(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: "xla_hlo.select"(%arg0, %arg1, %arg2) - %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - return %0: tensor<2xi32> -} - -// CHECK-LABEL: func @select_float -func @select_float(%arg0: tensor<2xi1>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> tensor<2xf32> { - // CHECK-NEXT: "xla_hlo.select"(%arg0, %arg1, %arg2) - %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> - return %0: tensor<2xf32> -} - -// CHECK-LABEL: func @select_multidimensional -func @select_multidimensional(%arg0: tensor<3x2xi1>, %arg1: tensor<3x2xi32>, %arg2: tensor<3x2xi32>) -> tensor<3x2xi32> { - // CHECK-NEXT: "xla_hlo.select"(%arg0, %arg1, %arg2) - %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3x2xi1>, tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> - return %0: tensor<3x2xi32> -} - // CHECK-LABEL: func @selectv2 func @selectv2(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { // CHECK-NEXT: "xla_hlo.select"(%arg0, %arg1, %arg2) @@ -1737,6 +1383,14 @@ func @selectv2_broadcast_pred(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %ar return %0: tensor<2x8x8xi32> } +// CHECK-LABEL: func @selectv2_broadcast_tensor_pred +func @selectv2_broadcast_tensor_pred(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> { + // CHECK: %[[BROADCAST:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi1>) -> tensor<2x3xi1> + // CHECK: "xla_hlo.select"(%[[BROADCAST]], %arg1, %arg2) + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16> + return %0: tensor<2x3xf16> +} + // CHECK-LABEL: func @selectv2_broadcast_all func @selectv2_broadcast_all(%arg0: tensor<8x1x1xi1>, %arg1: tensor<1x8x1xi32>, %arg2: tensor<1x1x8xi32>) -> tensor<8x8x8xi32> { // CHECK-DAG: %[[BROADCAST_0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x1x1xi1>) -> tensor<8x8x8xi1> @@ -1778,7 +1432,10 @@ func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK: {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf32>, tensor) -> tensor<2xf32> // CHECK: %[[CASTED_MAX:.*]] = "xla_hlo.convert"(%[[MAX]]) : (tensor<2xf32>) -> tensor<2xf32> - // CHECK: %[[SHIFTED_INP:.*]] = "xla_hlo.subtract"(%[[ARG0]], %[[CASTED_MAX]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]] + // CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) : (!shape.shape) -> tensor<2xindex> + // CHECK: %[[BCAST_MAX:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[CASTED_MAX]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: %[[SHIFTED_INP:.*]] = xla_hlo.subtract %[[ARG0]], %[[BCAST_MAX]] // CHECK: %[[EXP:.*]] = "xla_hlo.exponential"(%[[SHIFTED_INP]]) // Verify reduce op for summation and its body. 
@@ -1790,8 +1447,11 @@ func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK: {dimensions = dense<1> : tensor<1xi64>} // CHECK: %[[CASTED_SUM:.*]] = "xla_hlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32> - // CHECK: %[[RESULT:.*]] = "xla_hlo.divide"(%[[EXP]], %[[CASTED_SUM]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // return %[[RESULT]] + // CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]] + // CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) : (!shape.shape) -> tensor<2xindex> + // CHECK: %[[BCAST_SUM:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[CASTED_SUM]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: %[[RESULT:.*]] = xla_hlo.divide %[[EXP]], %[[BCAST_SUM]] + // CHECK: return %[[RESULT]] %0 = "tf.Softmax"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> return %0: tensor<2x3xf32> @@ -1800,7 +1460,7 @@ func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // Verify intermediate and final shape are correct with dynamic shapes. // CHECK-LABEL: func @dynamic_softmax func @dynamic_softmax(%arg0: tensor) -> tensor { - // CHECK: "xla_hlo.divide"({{.*}}) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor, tensor) -> tensor + // CHECK: xla_hlo.divide {{.*}} : tensor %0 = "tf.Softmax"(%arg0) : (tensor) -> tensor return %0: tensor } @@ -1826,43 +1486,29 @@ func @rank4_softmax(%arg0: tensor<2x3x4x5xf16>) -> tensor<2x3x4x5xf16> { // CHECK: "xla_hlo.reduce" // CHECK: dimensions = dense<3> - // CHECK: "xla_hlo.divide"{{.*}} {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} + // CHECK: {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} + // CHECK: xla_hlo.divide {{.*}} %0 = "tf.Softmax"(%arg0) : (tensor<2x3x4x5xf16>) -> tensor<2x3x4x5xf16> return %0: tensor<2x3x4x5xf16> } //===----------------------------------------------------------------------===// // LogSoftmax op legalizations. +// This just changes the tail of the regular Softmax legalization //===----------------------------------------------------------------------===// // CHECK-LABEL: func @simple_logsoftmax // CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3xf32>) func @simple_logsoftmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { - - // Verify reduce op for max computation and its body. - // CHECK-DAG: %[[CASTED_INP:.*]] = "xla_hlo.convert"(%[[ARG0]]) : (tensor<2x3xf32>) -> tensor<2x3xf32> - // CHECK-DAG: %[[NEG_INF:.*]] = xla_hlo.constant dense<0xFF800000> : tensor - // CHECK: %[[MAX:.*]] = "xla_hlo.reduce"(%[[CASTED_INP]], %[[NEG_INF]]) - // CHECK: xla_hlo.maximum - // CHECK: "xla_hlo.return" - // CHECK: {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf32>, tensor) -> tensor<2xf32> - // CHECK: %[[CASTED_MAX:.*]] = "xla_hlo.convert"(%[[MAX]]) : (tensor<2xf32>) -> tensor<2xf32> - - // CHECK: %[[SHIFTED_INP:.*]] = "xla_hlo.subtract"(%[[ARG0]], %[[CASTED_MAX]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // CHECK: %[[EXP:.*]] = "xla_hlo.exponential"(%[[SHIFTED_INP]]) - - // Verify reduce op for summation and its body. 
- // CHECK-DAG: %[[CASTED_EXP:.*]] = "xla_hlo.convert"(%[[EXP]]) : (tensor<2x3xf32>) -> tensor<2x3xf32> - // CHECK-DAG: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor - // CHECK: %[[SUM:.*]] = "xla_hlo.reduce"(%[[CASTED_EXP]], %[[ZERO]]) - // CHECK: xla_hlo.add - // CHECK: "xla_hlo.return" - // CHECK: {dimensions = dense<1> : tensor<1xi64>} + // CHECK: %{{.*}} = "xla_hlo.reduce"({{.*}}) + // CHECK: %[[SUM:.*]] = "xla_hlo.reduce"({{.*}}) // CHECK: %[[CASTED_SUM:.*]] = "xla_hlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32> // CHECK: %[[LOG:.*]] = "xla_hlo.log"(%[[CASTED_SUM]]) : (tensor<2xf32>) -> tensor<2xf32> - - // CHECK: %[[RESULT:.*]] = "xla_hlo.subtract"(%[[SHIFTED_INP]], %[[LOG]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // return %[[RESULT]] + // CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]] + // CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) : (!shape.shape) -> tensor<2xindex> + // CHECK: %[[BCAST_SUM:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[LOG]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: %[[RESULT:.*]] = xla_hlo.subtract {{.*}}, %[[BCAST_SUM]] + // CHECK: return %[[RESULT]] %0 = "tf.LogSoftmax"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> return %0: tensor<2x3xf32> @@ -2163,16 +1809,41 @@ func @neg_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { // CHECK-LABEL: @sigmoid func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK-DAG: [[R0:%.+]] = xla_hlo.constant dense<5.000000e-01> : tensor - // CHECK-DAG: [[R1:%.+]] = "xla_hlo.broadcast"([[R0]]) {broadcast_sizes = dense<2> : tensor<1xi64>} : (tensor) -> tensor<2xf32> - // CHECK-DAG: [[R2:%.+]] = xla_hlo.multiply %arg0, [[R1]] : tensor<2xf32> - // CHECK-DAG: [[R3:%.+]] = "xla_hlo.tanh"([[R2]]) : (tensor<2xf32>) -> tensor<2xf32> - // CHECK-DAG: [[R4:%.+]] = xla_hlo.multiply [[R3]], [[R1]] : tensor<2xf32> - // CHECK-DAG: [[R5:%.+]] = xla_hlo.add [[R4]], [[R1]] : tensor<2xf32> + // CHECK-DAG: [[SCALAR:%.+]] = xla_hlo.constant dense<5.000000e-01> : tensor + // CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0 : tensor<2xf32> + // CHECK-DAG: [[SHAPE_VAL:%.+]] = "shape.to_extent_tensor"([[SHAPE]]) : (!shape.shape) -> tensor<1xindex> + // CHECK-DAG: [[HALF:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"([[SCALAR]], [[SHAPE_VAL]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor<2xf32> + // CHECK-DAG: [[R1:%.+]] = xla_hlo.multiply %arg0, [[HALF]] : tensor<2xf32> + // CHECK-DAG: [[R2:%.+]] = "xla_hlo.tanh"([[R1]]) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK-DAG: [[R3:%.+]] = xla_hlo.multiply [[R2]], [[HALF]] : tensor<2xf32> + // CHECK-DAG: [[R4:%.+]] = xla_hlo.add [[R3]], [[HALF]] : tensor<2xf32> %0 = "tf.Sigmoid"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: @sigmoid_complex +func @sigmoid_complex(%arg0: tensor<2xcomplex>) -> tensor<2xcomplex> { + // CHECK: [[R0:%.+]] = xla_hlo.constant dense<(5.000000e-01,0.000000e+00)> : tensor> + // CHECK-NOT: tf.Sigmoid + %0 = "tf.Sigmoid"(%arg0) : (tensor<2xcomplex>) -> tensor<2xcomplex> + return %0 : tensor<2xcomplex> +} + +// CHECK-LABEL: @sigmoid_unranked +func @sigmoid_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { + // CHECK-DAG: [[SCALAR:%.+]] = xla_hlo.constant dense<5.000000e-01> : tensor + // CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0 : tensor<*xf32> + // CHECK-DAG: [[SHAPE_VAL:%.+]] = "shape.to_extent_tensor"([[SHAPE]]) : (!shape.shape) -> tensor + // CHECK-DAG: [[HALF:%.+]] = 
"xla_hlo.dynamic_broadcast_in_dim"([[SCALAR]], [[SHAPE_VAL]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor<*xf32> + // CHECK-DAG: [[R1:%.+]] = xla_hlo.multiply %arg0, [[HALF]] : tensor<*xf32> + // CHECK-DAG: [[R2:%.+]] = "xla_hlo.tanh"([[R1]]) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK-DAG: [[R3:%.+]] = xla_hlo.multiply [[R2]], [[HALF]] : tensor<*xf32> + // CHECK-DAG: [[R4:%.+]] = xla_hlo.add [[R3]], [[HALF]] : tensor<*xf32> + %0 = "tf.Sigmoid"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + + // CHECK-LABEL: @sigmoid_grad func @sigmoid_grad(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { // CHECK-DAG: [[MUL0:%.+]] = xla_hlo.multiply %arg1, %arg0 : tensor<2xf32> @@ -2184,6 +1855,17 @@ func @sigmoid_grad(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: @sigmoid_grad_complex +func @sigmoid_grad_complex(%arg0: tensor<2xcomplex>, %arg1: tensor<2xcomplex>) -> tensor<2xcomplex> { + // CHECK-DAG: [[MUL0:%.+]] = xla_hlo.multiply %arg1, %arg0 : tensor<2xcomplex> + // CHECK-DAG: [[ONE:%.+]] = xla_hlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor<2xcomplex> + // CHECK-DAG: [[SUB:%.+]] = xla_hlo.subtract [[ONE]], %arg0 : tensor<2xcomplex> + // CHECK-DAG: [[MUL1:%.+]] = xla_hlo.multiply [[MUL0]], [[SUB]] : tensor<2xcomplex> + // CHECK: return [[MUL1]] + %0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor<2xcomplex>, tensor<2xcomplex>) -> tensor<2xcomplex> + return %0 : tensor<2xcomplex> +} + // CHECK-LABEL: @sin func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK: "xla_hlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> @@ -2205,13 +1887,6 @@ func @sin_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { return %0 : tensor<*xf32> } -// CHECK-LABEL: func @round -func @round(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: "xla_hlo.round_nearest_afz"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - %0 = "tf.Round"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - return %0 : tensor<2xf32> -} - // CHECK-LABEL: func @rsqrt func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK: "xla_hlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> @@ -2505,6 +2180,18 @@ func @simple_strided_slice(%input: tensor<4x8xf32>) -> tensor<3x2xf32> { return %output : tensor<3x2xf32> } +// CHECK-LABEL: dynamic_strided_slice +func @dynamic_strided_slice(%input: tensor) -> tensor { + %begin = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> (tensor<2xi32>) + %end = "tf.Const"() {value = dense<[3, 7]> : tensor<2xi32>} : () -> (tensor<2xi32>) + %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>) + + // CHECK: "tf.StridedSlice" + %output = "tf.StridedSlice"(%input, %begin, %end, %strides) + : (tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor + return %output : tensor +} + // CHECK-LABEL: strided_slice_negative_indices func @strided_slice_negative_indices(%input: tensor<4x8xf32>) -> tensor<3x2xf32> { %begin = "tf.Const"() {value = dense<[-1, -2]> : tensor<2xi32>} : () -> (tensor<2xi32>) @@ -2524,6 +2211,18 @@ func @strided_slice_negative_indices(%input: tensor<4x8xf32>) -> tensor<3x2xf32> return %output : tensor<3x2xf32> } +// CHECK-LABEL: dynamic_strided_slice_negative_indices +func @dynamic_strided_slice_negative_indices(%input: tensor) -> tensor { + %begin = "tf.Const"() {value = dense<[-1, -2]> : tensor<2xi32>} : () -> (tensor<2xi32>) + %end = "tf.Const"() {value = dense<[-4, -8]> : tensor<2xi32>} : () -> (tensor<2xi32>) + 
%strides = "tf.Const"() {value = dense<[-1, -3]> : tensor<2xi32>} : () -> (tensor<2xi32>) + + // CHECK: tf.StridedSlice + %output = "tf.StridedSlice"(%input, %begin, %end, %strides) + : (tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor + return %output : tensor +} + // CHECK-LABEL: strided_slice_range_clamping func @strided_slice_range_clamping(%input: tensor<4x8xf32>) -> tensor<0x3xf32> { %begin = "tf.Const"() {value = dense<[-4, -10]> : tensor<2xi32>} : () -> (tensor<2xi32>) @@ -2720,10 +2419,10 @@ func @strided_slice_nonconstant_begin_end(%arg0: tensor, %arg1: tensor<32x1 // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32> // CHECK-NEXT: %[[INDEX2:.*]] = "xla_hlo.reshape"(%[[INDEX]]) : (tensor<1xi32>) -> tensor - // CHECK-NEXT: %[[CMP:.*]] = "xla_hlo.compare"(%[[INDEX2]], %[[ZERO]]) + // CHECK-NEXT: %[[CMP:.*]] = xla_chlo.broadcast_compare %[[INDEX2]], %[[ZERO]] // CHECK-DAG-SAME: {comparison_direction = "LT"} : (tensor, tensor) -> tensor // CHECK-NEXT: %[[DIM:.*]] = xla_hlo.constant dense<32> : tensor - // CHECK-NEXT: %[[WRAP:.*]] = xla_hlo.add %[[DIM]], %[[INDEX2]] : tensor + // CHECK-NEXT: %[[WRAP:.*]] = xla_chlo.broadcast_add %[[DIM]], %[[INDEX2]] : (tensor, tensor) -> tensor // CHECK-NEXT: %[[INDEX3:.*]] = "xla_hlo.select"(%[[CMP]], %[[WRAP]], %[[INDEX2]]) : // CHECK-DAG-SAME: (tensor, tensor, tensor) -> tensor // CHECK-NEXT: %[[SLICED:.*]] = "xla_hlo.dynamic-slice" @@ -2852,7 +2551,7 @@ func @mean(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { // CHECK: "xla_hlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor) -> () // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor) -> tensor<4xf32> // CHECK: %[[DIVISOR:.*]] = xla_hlo.constant dense<8.000000e+00> : tensor - // CHECK: %[[MEAN:.*]] = "xla_hlo.divide"(%[[REDUCED]], %[[DIVISOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: %[[MEAN:.*]] = xla_chlo.broadcast_divide %[[REDUCED]], %[[DIVISOR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> // CHECK: %[[CAST_BACK:.*]] = "xla_hlo.convert"(%[[MEAN]]) : (tensor<4xf32>) -> tensor<4xf16> // CHECK: %[[RESULT:.*]] = "xla_hlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> // CHECK: return %[[RESULT]] : tensor<4x1xf16> @@ -3156,8 +2855,8 @@ func @rng_std_normal(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { func @range(%arg0: tensor, %arg1: tensor) -> tensor<5xf32> { %1 = "tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "range/limit", value = dense<5.000000e+00> : tensor} : () -> tensor // CHECK-DAG: [[IOTA:%.*]] = "xla_hlo.iota" - // CHECK-DAG: [[MUL:%.*]] = "xla_hlo.multiply"([[IOTA]], [[DELTA]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} - // CHECK: "xla_hlo.add"([[MUL]], [[START]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} + // CHECK-DAG: [[MUL:%.*]] = xla_chlo.broadcast_multiply [[IOTA]], [[DELTA]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} + // CHECK: xla_chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} %3 = "tf.Range"(%arg0, %1, %arg1) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor, tensor, tensor) -> tensor<5xf32> return %3 : tensor<5xf32> } @@ -3169,12 +2868,12 @@ func @linspace_static(%arg0: tensor, %arg1: tensor) -> tensor<4xf32> { // CHECK-DAG: [[NUM_CAST:%.*]] = tensor_cast [[NUM]] // CHECK-DAG: [[NUM_F32:%.*]] = 
"xla_hlo.convert"([[NUM_CAST]]) // CHECK-DAG: [[ONE:%.*]] = xla_hlo.constant dense<1.000000e+00> - // CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = xla_hlo.subtract [[NUM_F32]], [[ONE]] - // CHECK-DAG: [[STEP_NUMERATOR:%.*]] = xla_hlo.subtract [[STOP]], [[START]] - // CHECK-DAG: [[STEP:%.*]] = xla_hlo.divide [[STEP_NUMERATOR]], [[STEP_DENOMINATOR]] + // CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = xla_chlo.broadcast_subtract [[NUM_F32]], [[ONE]] + // CHECK-DAG: [[STEP_NUMERATOR:%.*]] = xla_chlo.broadcast_subtract [[STOP]], [[START]] + // CHECK-DAG: [[STEP:%.*]] = xla_chlo.broadcast_divide [[STEP_NUMERATOR]], [[STEP_DENOMINATOR]] // CHECK-DAG: [[IOTA:%.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} - // CHECK-DAG: [[MUL:%.*]] = "xla_hlo.multiply"([[IOTA]], [[STEP]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} - // CHECK-DAG: [[LINSPACE:%.*]] = "xla_hlo.add"([[MUL]], [[START]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} + // CHECK-DAG: [[MUL:%.*]] = xla_chlo.broadcast_multiply [[IOTA]], [[STEP]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} + // CHECK-DAG: [[LINSPACE:%.*]] = xla_chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} // CHECK: return [[LINSPACE]] %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<4> : tensor} : () -> tensor %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor, tensor, tensor) -> tensor<4xf32> @@ -3469,13 +3168,13 @@ func @size_ranked(%input: tensor<2x?x8xf32>) -> (tensor) { // CHECK: %[[CONST:.*]] = xla_hlo.constant dense<1> // CHECK: %[[DIM_0:.*]] = "xla_hlo.get_dimension_size"(%[[INPUT]]) // CHECK-SAME: dimension = 0 - // CHECK: %[[MUL_0:.*]] = xla_hlo.multiply %[[CONST]], %[[DIM_0]] + // CHECK: %[[MUL_0:.*]] = xla_chlo.broadcast_multiply %[[CONST]], %[[DIM_0]] // CHECK: %[[DIM_1:.*]] = "xla_hlo.get_dimension_size"(%[[INPUT]]) // CHECK-SAME: dimension = 1 - // CHECK: %[[MUL_1:.*]] = xla_hlo.multiply %[[MUL_0]], %[[DIM_1]] + // CHECK: %[[MUL_1:.*]] = xla_chlo.broadcast_multiply %[[MUL_0]], %[[DIM_1]] // CHECK: %[[DIM_2:.*]] = "xla_hlo.get_dimension_size"(%[[INPUT]]) // CHECK-SAME: dimension = 2 - // CHECK: %[[MUL_2:.*]] = xla_hlo.multiply %[[MUL_1]], %[[DIM_2]] + // CHECK: %[[MUL_2:.*]] = xla_chlo.broadcast_multiply %[[MUL_1]], %[[DIM_2]] %size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor<2x?x8xf32>) -> tensor // CHECK: return %[[MUL_2]] return %size : tensor @@ -3632,30 +3331,31 @@ func @assert(%arg0: tensor, %arg1: tensor<*xf32>) { // tf.Unpack legalization //===----------------------------------------------------------------------===// -// CHECK-LABEL: @unpack -func @unpack(%input: tensor<4x3x6xf32>) -> (tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>) { - // CHECK: %[[SLICE1:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, 1, 6]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> - // CHECK: %[[RES1:.*]] = "xla_hlo.reshape"(%[[SLICE1]]) : (tensor<4x1x6xf32>) -> tensor<4x?xf32> - // CHECK: %[[SLICE2:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, 2, 6]> : tensor<3xi64>, start_indices = dense<[0, 1, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> - // CHECK: %[[RES2:.*]] = "xla_hlo.reshape"(%[[SLICE2]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[SLICE3:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, 3, 6]> : tensor<3xi64>, start_indices = dense<[0, 2, 0]> : 
tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> - // CHECK: %[[RES3:.*]] = "xla_hlo.reshape"(%[[SLICE3]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32> +// TODO(b/156340000): Re-enable when fixed. +// // C-HECK-LABEL: @unpack +// func @unpack(%input: tensor<4x3x6xf32>) -> (tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>) { +// // C-HECK: %[[SLICE1:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, 1, 6]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> +// // C-HECK: %[[RES1:.*]] = "xla_hlo.reshape"(%[[SLICE1]]) : (tensor<4x1x6xf32>) -> tensor<4x?xf32> +// // C-HECK: %[[SLICE2:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, 2, 6]> : tensor<3xi64>, start_indices = dense<[0, 1, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> +// // C-HECK: %[[RES2:.*]] = "xla_hlo.reshape"(%[[SLICE2]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32> +// // C-HECK: %[[SLICE3:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, 3, 6]> : tensor<3xi64>, start_indices = dense<[0, 2, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> +// // C-HECK: %[[RES3:.*]] = "xla_hlo.reshape"(%[[SLICE3]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32> - %0:3 = "tf.Unpack"(%input) {axis = 1} : (tensor<4x3x6xf32>) -> (tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>) - // return %[[RES1]], %[[RES2]], %[[RES3]] - return %0#0, %0#1, %0#2 : tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32> -} +// %0:3 = "tf.Unpack"(%input) {axis = 1} : (tensor<4x3x6xf32>) -> (tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>) +// // return %[[RES1]], %[[RES2]], %[[RES3]] +// return %0#0, %0#1, %0#2 : tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32> +// } -// CHECK-LABEL: @unpack_dynamic -func @unpack_dynamic(%input: tensor) -> (tensor, tensor) { - // CHECK: %[[SLICE1:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, -1, 1]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor) -> tensor - // CHECK: "xla_hlo.reshape"(%[[SLICE1]]) : (tensor) -> tensor - // CHECK: %[[SLICE2:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, -1, 2]> : tensor<3xi64>, start_indices = dense<[0, 0, 1]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor) -> tensor - // CHECK: "xla_hlo.reshape"(%[[SLICE2]]) : (tensor) -> tensor +// // C-HECK-LABEL: @unpack_dynamic +// func @unpack_dynamic(%input: tensor) -> (tensor, tensor) { +// // C-HECK: %[[SLICE1:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, -1, 1]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor) -> tensor +// // C-HECK: "xla_hlo.reshape"(%[[SLICE1]]) : (tensor) -> tensor +// // C-HECK: %[[SLICE2:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, -1, 2]> : tensor<3xi64>, start_indices = dense<[0, 0, 1]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor) -> tensor +// // C-HECK: "xla_hlo.reshape"(%[[SLICE2]]) : (tensor) -> tensor - %0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor) -> (tensor, tensor) - return %0#0, %0#1 : tensor, tensor -} +// %0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor) -> (tensor, tensor) +// return %0#0, %0#1 : tensor, tensor +// } //===----------------------------------------------------------------------===// // tf.UnsortedSegment{Max|Min|Prod|Sum} legalization @@ -3720,11 
+3420,11 @@ func @unsorted_segment_max(%data: tensor<8x?x64xf32>, %segment_ids : tensor, %arg1: tensor<16x5xi32>) -> tensor<16x2x5x3xf32> { - // CHECK: "xla_hlo.torch_index_select"(%arg0, %arg1) {batch_dims = 1 : i64, dim = 2 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>) -> tensor<16x2x5x3xf32> +func @gather_v2(%arg0: tensor<16x2x3xf32>, %arg1: tensor<16x5xi32>) -> tensor<16x2x5xf32> { + // CHECK: "xla_hlo.torch_index_select"(%arg0, %arg1) {batch_dims = 1 : i64, dim = 2 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>) -> tensor<16x2x5xf32> %0 = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32> - %1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = -1 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>, tensor<1xi32>) -> tensor<16x2x5x3xf32> - return %1 : tensor<16x2x5x3xf32> + %1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = -1 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>, tensor<1xi32>) -> tensor<16x2x5xf32> + return %1 : tensor<16x2x5xf32> } // CHECK-LABEL: @gather_v2_dynamic @@ -3991,7 +3691,7 @@ func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> { // CHECK: [[INDICES1:%.*]] = "xla_hlo.dynamic-update-slice"([[INDICES]], [[TGT_IDX]], [[IV]]) : (tensor<4xi32>, tensor<1xi32>, tensor) -> tensor<4xi32> // CHECK: [[INDICES2:%.*]] = "xla_hlo.dynamic-update-slice"([[INDICES1]], [[SRC_IDX]], [[SWP]]) : (tensor<4xi32>, tensor<1xi32>, tensor) -> tensor<4xi32> // CHECK: [[ONE:%.*]] = xla_hlo.constant dense<1> : tensor - // CHECK: [[NEW_IV:%.*]] = xla_hlo.add [[IV]], [[ONE]] + // CHECK: [[NEW_IV:%.*]] = xla_chlo.broadcast_add [[IV]], [[ONE]] // CHECK: [[NEW_TUPLE:%.*]] = "xla_hlo.tuple"([[NEW_IV]], [[SWAPS]], [[INDICES2]]) // CHECK: "xla_hlo.return"([[NEW_TUPLE]]) // CHECK: }) : (tuple, tensor<4xi32>, tensor<4xi32>>) -> tuple, tensor<4xi32>, tensor<4xi32>> @@ -4060,7 +3760,7 @@ func @avgpool_valid_padding(%arg0: tensor<2x12x20x7xf16>) -> tensor<2x3x5x7xf16> // CHECK: "xla_hlo.return"([[ADD]]) // CHECK: }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>} : (tensor<2x12x20x7xf32>, tensor) -> tensor<2x3x5x7xf32> // CHECK: [[COUNT:%.+]] = xla_hlo.constant dense<4.000000e+00> : tensor - // CHECK: [[DIV:%.+]] = "xla_hlo.divide"([[REDUCE]], [[COUNT]]) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<2x3x5x7xf32>, tensor) -> tensor<2x3x5x7xf32> + // CHECK: [[DIV:%.+]] = xla_chlo.broadcast_divide [[REDUCE]], [[COUNT]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<2x3x5x7xf32>, tensor) -> tensor<2x3x5x7xf32> // CHECK: [[CONV16:%.+]] = "xla_hlo.convert"([[DIV]]) : (tensor<2x3x5x7xf32>) -> tensor<2x3x5x7xf16> // CHECK: return [[CONV16]] %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xf16>) -> tensor<2x3x5x7xf16> @@ -4081,6 +3781,41 @@ func @xla_sharding(%arg0: tensor<4x16xf32>) -> tensor<4x16xf32> { return %0 : tensor<4x16xf32> } +// CHECK-LABEL: inplace_update_one +func @inplace_update_one(%arg0: tensor<8x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor<1xi32>) -> tensor<8x4xf32> { + // CHECK-DAG: [[CST:%.+]] = xla_hlo.constant dense<0> + // CHECK-DAG: [[SLICE1:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[SLICE2:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : 
tensor<2xi64>} + // CHECK-DAG: [[RESHAPE1:%.+]] = "xla_hlo.reshape"([[SLICE1]]) + // CHECK-DAG: [[UPDATE:%.+]] = "xla_hlo.dynamic-update-slice"(%arg0, [[SLICE2]], [[RESHAPE1]], [[CST]]) + %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x4xf32>, tensor<1xi32>, tensor<1x4xf32>) -> tensor<8x4xf32> + + // CHECK: return [[UPDATE]] + return %0 : tensor<8x4xf32> +} + +// CHECK-LABEL: inplace_update_three +func @inplace_update_three(%arg0: tensor<8x8x4xf32>, %arg1: tensor<3x8x4xf32>, %arg2: tensor<3xi32>) -> tensor<8x8x4xf32> { + // CHECK-DAG: [[CST:%.+]] = xla_hlo.constant dense<0> + // CHECK-DAG: [[SLICE1:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[SLICE2:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[SLICE3:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<3> : tensor<1xi64>, start_indices = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[SLICE4:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[1, 8, 4]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} + // CHECK-DAG: [[SLICE5:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[2, 8, 4]> : tensor<3xi64>, start_indices = dense<[1, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} + // CHECK-DAG: [[SLICE6:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[3, 8, 4]> : tensor<3xi64>, start_indices = dense<[2, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} + // CHECK-DAG: [[RESHAPE1:%.+]] = "xla_hlo.reshape"([[SLICE1]]) + // CHECK-DAG: [[RESHAPE2:%.+]] = "xla_hlo.reshape"([[SLICE2]]) + // CHECK-DAG: [[RESHAPE3:%.+]] = "xla_hlo.reshape"([[SLICE3]]) + // CHECK-DAG: [[UPDATE1:%.+]] = "xla_hlo.dynamic-update-slice"(%arg0, [[SLICE4]], [[RESHAPE1]], [[CST]], [[CST]]) + // CHECK-DAG: [[UPDATE2:%.+]] = "xla_hlo.dynamic-update-slice"([[UPDATE1]], [[SLICE5]], [[RESHAPE2]], [[CST]], [[CST]]) + // CHECK-DAG: [[UPDATE3:%.+]] = "xla_hlo.dynamic-update-slice"([[UPDATE2]], [[SLICE6]], [[RESHAPE3]], [[CST]], [[CST]]) + %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x8x4xf32>, tensor<3xi32>, tensor<3x8x4xf32>) -> tensor<8x8x4xf32> + + // CHECK: return [[UPDATE3]] : tensor<8x8x4xf32> + return %0 : tensor<8x8x4xf32> +} + + // CHECK-LABEL: xla_dynamic_update_slice func @xla_dynamic_update_slice(%arg0: tensor<4x16xf32>, %arg1: tensor<2x4xf32>, %arg2: tensor<2xi32>) -> tensor<4x16xf32> { // CHECK: [[SLICE0:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> @@ -4103,6 +3838,21 @@ func @xla_dynamic_update_slice2(%arg0: tensor<4xf32>, %arg1: tensor<2xf32>, %arg return %0 : tensor<4xf32> } +//===----------------------------------------------------------------------===// +// AllToAll op legalizations. 
+//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @alltoall_basic +func @alltoall_basic(%input: tensor<10xf32>) -> tensor<10xf32> { + %group_assignment = "tf.Const" () { + value = dense<[[0, 2, 4, 6], [1, 3, 5, 7], [3, 5, 6, 8]]> : tensor<3x4xi32> + } : () -> tensor<3x4xi32> + %result = "tf.AllToAll"(%input, %group_assignment) {T = f32, concat_dimension = 1 : i64, split_count = 2 : i64, split_dimension = 0 : i64} : (tensor<10xf32>, tensor<3x4xi32>) -> tensor<10xf32> + // CHECK: xla_hlo.all_to_all + // CHECK-SAME: replica_groups = dense<{{\[}}[0, 2, 4, 6], [1, 3, 5, 7], [3, 5, 6, 8]]> : tensor<3x4xi64> + return %result : tensor<10xf32> +} + //===----------------------------------------------------------------------===// // Cumsum op legalizations. //===----------------------------------------------------------------------===// @@ -4150,177 +3900,11 @@ func @cumsum_dynamic(%arg0: tensor, %arg1: tensor) -> tensor // CHECK: func @qr([[VAL_0:%.*]]: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) func @qr(%arg0: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) { -// CHECK: [[VAL_1:%.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<100x100xi32> -// CHECK: [[VAL_2:%.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<100x100xi32> -// CHECK: [[VAL_3:%.*]] = "xla_hlo.compare"([[VAL_1]], [[VAL_2]]) {comparison_direction = "EQ"} : (tensor<100x100xi32>, tensor<100x100xi32>) -> tensor<100x100xi1> -// CHECK: [[VAL_4:%.*]] = "xla_hlo.convert"([[VAL_3]]) : (tensor<100x100xi1>) -> tensor<100x100xf32> -// CHECK: [[VAL_5:%.*]] = "xla_hlo.broadcast"([[VAL_4]]) {broadcast_sizes = dense<500> : tensor<1xi64>} : (tensor<100x100xf32>) -> tensor<500x100x100xf32> -// CHECK: [[VAL_6:%.*]] = "xla_hlo.slice"([[VAL_0]]) {limit_indices = dense<[500, 100, 75]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<500x100x75xf32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_7:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor -// CHECK: [[VAL_8:%.*]] = "xla_hlo.broadcast"([[VAL_7]]) {broadcast_sizes = dense<[500, 100, 75]> : tensor<3xi64>} : (tensor) -> tensor<500x100x75xf32> -// CHECK: [[VAL_9:%.*]] = "xla_hlo.broadcast"([[VAL_7]]) {broadcast_sizes = dense<[500, 75]> : tensor<2xi64>} : (tensor) -> tensor<500x75xf32> -// CHECK: [[VAL_10:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_11:%.*]] = "xla_hlo.tuple"([[VAL_10]], [[VAL_6]], [[VAL_8]], [[VAL_9]]) : (tensor, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>) -> tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>> -// CHECK: [[VAL_12:%.*]] = "xla_hlo.while"([[VAL_11]]) ( { -// CHECK: ^bb0([[VAL_13:%.*]]: tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>): -// CHECK: [[VAL_14:%.*]] = "xla_hlo.get_tuple_element"([[VAL_13]]) {index = 0 : i32} : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor -// CHECK: [[VAL_15:%.*]] = xla_hlo.constant dense<75> : tensor -// CHECK: [[VAL_16:%.*]] = "xla_hlo.compare"([[VAL_14]], [[VAL_15]]) {comparison_direction = "LT"} : (tensor, tensor) -> tensor -// CHECK: "xla_hlo.return"([[VAL_16]]) : (tensor) -> () -// CHECK: }, { -// CHECK: ^bb0([[VAL_17:%.*]]: tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>): -// CHECK: [[VAL_18:%.*]] = "xla_hlo.get_tuple_element"([[VAL_17]]) {index = 0 : i32} : (tuple, 
tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor -// CHECK: [[VAL_19:%.*]] = "xla_hlo.get_tuple_element"([[VAL_17]]) {index = 1 : i32} : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_20:%.*]] = "xla_hlo.get_tuple_element"([[VAL_17]]) {index = 2 : i32} : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_21:%.*]] = "xla_hlo.get_tuple_element"([[VAL_17]]) {index = 3 : i32} : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x75xf32> -// CHECK: [[VAL_22:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_23:%.*]] = "xla_hlo.dynamic-slice"([[VAL_19]], [[VAL_22]], [[VAL_22]], [[VAL_18]]) {slice_sizes = dense<[500, 100, 1]> : tensor<3xi64>} : (tensor<500x100x75xf32>, tensor, tensor, tensor) -> tensor<500x100x1xf32> -// CHECK: [[VAL_24:%.*]] = "xla_hlo.reshape"([[VAL_23]]) : (tensor<500x100x1xf32>) -> tensor<500x100xf32> -// CHECK: [[VAL_25:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor -// CHECK: [[VAL_26:%.*]] = xla_hlo.constant dense<1.000000e+00> : tensor -// CHECK: [[VAL_27:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_28:%.*]] = "xla_hlo.dynamic-slice"([[VAL_24]], [[VAL_27]], [[VAL_18]]) {slice_sizes = dense<[500, 1]> : tensor<2xi64>} : (tensor<500x100xf32>, tensor, tensor) -> tensor<500x1xf32> -// CHECK: [[VAL_29:%.*]] = "xla_hlo.reshape"([[VAL_28]]) : (tensor<500x1xf32>) -> tensor<500xf32> -// CHECK: [[VAL_30:%.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<100xi32> -// CHECK: [[VAL_31:%.*]] = "xla_hlo.compare"([[VAL_30]], [[VAL_18]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor<100xi32>, tensor) -> tensor<100xi1> -// CHECK: [[VAL_32:%.*]] = "xla_hlo.convert"([[VAL_31]]) : (tensor<100xi1>) -> tensor<100xf32> -// CHECK: [[VAL_33:%.*]] = "xla_hlo.multiply"([[VAL_24]], [[VAL_32]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<500x100xf32>, tensor<100xf32>) -> tensor<500x100xf32> -// CHECK: [[VAL_34:%.*]] = xla_hlo.multiply [[VAL_33]], [[VAL_33]] : tensor<500x100xf32> -// CHECK: [[VAL_35:%.*]] = "xla_hlo.reduce"([[VAL_34]], [[VAL_25]]) ( { -// CHECK: ^bb0([[VAL_36:%.*]]: tensor, [[VAL_37:%.*]]: tensor): -// CHECK: [[VAL_38:%.*]] = xla_hlo.add [[VAL_36]], [[VAL_37]] : tensor -// CHECK: "xla_hlo.return"([[VAL_38]]) : (tensor) -> () -// CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<500x100xf32>, tensor) -> tensor<500xf32> -// CHECK: [[VAL_39:%.*]] = xla_hlo.multiply [[VAL_29]], [[VAL_29]] : tensor<500xf32> -// CHECK: [[VAL_40:%.*]] = xla_hlo.add [[VAL_39]], [[VAL_41:%.*]] : tensor<500xf32> -// CHECK: [[VAL_42:%.*]] = "xla_hlo.sqrt"([[VAL_40]]) : (tensor<500xf32>) -> tensor<500xf32> -// CHECK: [[VAL_43:%.*]] = "xla_hlo.compare"([[VAL_41]], [[VAL_25]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "EQ"} : (tensor<500xf32>, tensor) -> tensor<500xi1> -// CHECK: [[VAL_44:%.*]] = "xla_hlo.compare"([[VAL_29]], [[VAL_25]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"} : (tensor<500xf32>, tensor) -> tensor<500xi1> -// CHECK: [[VAL_45:%.*]] = "xla_hlo.broadcast"([[VAL_26]]) {broadcast_sizes = dense<500> : tensor<1xi64>} : (tensor) -> tensor<500xf32> -// CHECK: [[VAL_46:%.*]] = "xla_hlo.negate"([[VAL_45]]) : (tensor<500xf32>) -> tensor<500xf32> -// CHECK: [[VAL_47:%.*]] = "xla_hlo.select"([[VAL_44]], [[VAL_45]], 
[[VAL_46]]) : (tensor<500xi1>, tensor<500xf32>, tensor<500xf32>) -> tensor<500xf32> -// CHECK: [[VAL_48:%.*]] = xla_hlo.multiply [[VAL_47]], [[VAL_42]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<500xf32> -// CHECK: [[VAL_49:%.*]] = "xla_hlo.select"([[VAL_43]], [[VAL_29]], [[VAL_48]]) : (tensor<500xi1>, tensor<500xf32>, tensor<500xf32>) -> tensor<500xf32> -// CHECK: [[VAL_50:%.*]] = xla_hlo.subtract [[VAL_49]], [[VAL_29]] : tensor<500xf32> -// CHECK: [[VAL_51:%.*]] = xla_hlo.divide [[VAL_50]], [[VAL_49]] : tensor<500xf32> -// CHECK: [[VAL_52:%.*]] = "xla_hlo.broadcast"([[VAL_25]]) {broadcast_sizes = dense<500> : tensor<1xi64>} : (tensor) -> tensor<500xf32> -// CHECK: [[VAL_53:%.*]] = "xla_hlo.select"([[VAL_43]], [[VAL_52]], [[VAL_51]]) : (tensor<500xi1>, tensor<500xf32>, tensor<500xf32>) -> tensor<500xf32> -// CHECK: [[VAL_54:%.*]] = xla_hlo.subtract [[VAL_29]], [[VAL_49]] : tensor<500xf32> -// CHECK: [[VAL_55:%.*]] = "xla_hlo.select"([[VAL_43]], [[VAL_45]], [[VAL_54]]) : (tensor<500xi1>, tensor<500xf32>, tensor<500xf32>) -> tensor<500xf32> -// CHECK: [[VAL_56:%.*]] = "xla_hlo.compare"([[VAL_30]], [[VAL_18]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "EQ"} : (tensor<100xi32>, tensor) -> tensor<100xi1> -// CHECK: [[VAL_57:%.*]] = "xla_hlo.convert"([[VAL_56]]) : (tensor<100xi1>) -> tensor<100xf32> -// CHECK: [[VAL_58:%.*]] = "xla_hlo.broadcast"([[VAL_57]]) {broadcast_sizes = dense<1> : tensor<1xi64>} : (tensor<100xf32>) -> tensor<1x100xf32> -// CHECK: [[VAL_59:%.*]] = "xla_hlo.divide"([[VAL_33]], [[VAL_55]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<500x100xf32>, tensor<500xf32>) -> tensor<500x100xf32> -// CHECK: [[VAL_60:%.*]] = "xla_hlo.add"([[VAL_58]], [[VAL_59]]) : (tensor<1x100xf32>, tensor<500x100xf32>) -> tensor<500x100xf32> -// CHECK: [[VAL_61:%.*]] = "xla_hlo.reshape"([[VAL_60]]) : (tensor<500x100xf32>) -> tensor<500x1x100xf32> -// CHECK: [[VAL_62:%.*]] = "xla_hlo.dot_general"([[VAL_61]], [[VAL_19]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["HIGHEST", "HIGHEST"]} : (tensor<500x1x100xf32>, tensor<500x100x75xf32>) -> tensor<500x1x75xf32> -// CHECK: [[VAL_63:%.*]] = "xla_hlo.dot_general"([[VAL_61]], [[VAL_62]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<1> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["HIGHEST", "HIGHEST"]} : (tensor<500x1x100xf32>, tensor<500x1x75xf32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_64:%.*]] = "xla_hlo.multiply"([[VAL_53]], [[VAL_63]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<500xf32>, tensor<500x100x75xf32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_65:%.*]] = xla_hlo.subtract [[VAL_19]], [[VAL_64]] : tensor<500x100x75xf32> -// CHECK: [[VAL_66:%.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<100x1xi32> -// CHECK: [[VAL_67:%.*]] = "xla_hlo.compare"([[VAL_66]], [[VAL_18]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"} : (tensor<100x1xi32>, tensor) -> tensor<100x1xi1> -// CHECK: [[VAL_68:%.*]] = "xla_hlo.convert"([[VAL_67]]) : (tensor<100x1xi1>) -> tensor<100x1xf32> -// CHECK: [[VAL_69:%.*]] = "xla_hlo.compare"([[VAL_66]], [[VAL_18]]) 
{broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "EQ"} : (tensor<100x1xi32>, tensor) -> tensor<100x1xi1> -// CHECK: [[VAL_70:%.*]] = "xla_hlo.convert"([[VAL_69]]) : (tensor<100x1xi1>) -> tensor<100x1xf32> -// CHECK: [[VAL_71:%.*]] = "xla_hlo.broadcast"([[VAL_70]]) {broadcast_sizes = dense<1> : tensor<1xi64>} : (tensor<100x1xf32>) -> tensor<1x100x1xf32> -// CHECK: [[VAL_72:%.*]] = "xla_hlo.multiply"([[VAL_23]], [[VAL_68]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<500x100x1xf32>, tensor<100x1xf32>) -> tensor<500x100x1xf32> -// CHECK: [[VAL_73:%.*]] = "xla_hlo.multiply"([[VAL_49]], [[VAL_71]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<500xf32>, tensor<1x100x1xf32>) -> tensor<500x100x1xf32> -// CHECK: [[VAL_74:%.*]] = xla_hlo.add [[VAL_72]], [[VAL_73]] : tensor<500x100x1xf32> -// CHECK: [[VAL_75:%.*]] = "xla_hlo.broadcast_in_dim"([[VAL_74]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<500x100x1xf32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_76:%.*]] = "xla_hlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<500x100x75xi32> -// CHECK: [[VAL_77:%.*]] = "xla_hlo.compare"([[VAL_76]], [[VAL_18]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "EQ"} : (tensor<500x100x75xi32>, tensor) -> tensor<500x100x75xi1> -// CHECK: [[VAL_78:%.*]] = "xla_hlo.select"([[VAL_77]], [[VAL_75]], [[VAL_65]]) : (tensor<500x100x75xi1>, tensor<500x100x75xf32>, tensor<500x100x75xf32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_79:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor -// CHECK: [[VAL_80:%.*]] = "xla_hlo.broadcast"([[VAL_79]]) {broadcast_sizes = dense<[500, 100, 75]> : tensor<3xi64>} : (tensor) -> tensor<500x100x75xf32> -// CHECK: [[VAL_81:%.*]] = "xla_hlo.add"([[VAL_80]], [[VAL_60]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<500x100x75xf32>, tensor<500x100xf32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_82:%.*]] = "xla_hlo.select"([[VAL_77]], [[VAL_81]], [[VAL_80]]) : (tensor<500x100x75xi1>, tensor<500x100x75xf32>, tensor<500x100x75xf32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_83:%.*]] = xla_hlo.add [[VAL_20]], [[VAL_82]] : tensor<500x100x75xf32> -// CHECK: [[VAL_84:%.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<500x75xi32> -// CHECK: [[VAL_85:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor -// CHECK: [[VAL_86:%.*]] = "xla_hlo.broadcast"([[VAL_85]]) {broadcast_sizes = dense<[500, 75]> : tensor<2xi64>} : (tensor) -> tensor<500x75xf32> -// CHECK: [[VAL_87:%.*]] = "xla_hlo.compare"([[VAL_84]], [[VAL_18]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "EQ"} : (tensor<500x75xi32>, tensor) -> tensor<500x75xi1> -// CHECK: [[VAL_88:%.*]] = "xla_hlo.add"([[VAL_86]], [[VAL_53]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<500x75xf32>, tensor<500xf32>) -> tensor<500x75xf32> -// CHECK: [[VAL_89:%.*]] = "xla_hlo.select"([[VAL_87]], [[VAL_88]], [[VAL_86]]) : (tensor<500x75xi1>, tensor<500x75xf32>, tensor<500x75xf32>) -> tensor<500x75xf32> -// CHECK: [[VAL_90:%.*]] = xla_hlo.add [[VAL_21]], [[VAL_89]] : tensor<500x75xf32> -// CHECK: [[VAL_91:%.*]] = xla_hlo.constant dense<1> : tensor -// CHECK: [[VAL_92:%.*]] = xla_hlo.add [[VAL_18]], [[VAL_91]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor -// CHECK: [[VAL_93:%.*]] = "xla_hlo.tuple"([[VAL_92]], [[VAL_78]], [[VAL_83]], [[VAL_90]]) : (tensor, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>) -> tuple, tensor<500x100x75xf32>, 
tensor<500x100x75xf32>, tensor<500x75xf32>> -// CHECK: "xla_hlo.return"([[VAL_93]]) : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> () -// CHECK: }) : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>> -// CHECK: [[VAL_94:%.*]] = "xla_hlo.get_tuple_element"([[VAL_95:%.*]]) {index = 1 : i32} : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_96:%.*]] = "xla_hlo.get_tuple_element"([[VAL_95]]) {index = 2 : i32} : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_97:%.*]] = "xla_hlo.get_tuple_element"([[VAL_95]]) {index = 3 : i32} : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x75xf32> -// CHECK: [[VAL_98:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_99:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_100:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_101:%.*]] = "xla_hlo.dynamic-update-slice"([[VAL_0]], [[VAL_94]], [[VAL_100]], [[VAL_98]], [[VAL_99]]) : (tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor, tensor, tensor) -> tensor<500x100x75xf32> -// CHECK: [[VAL_102:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor -// CHECK: [[VAL_103:%.*]] = "xla_hlo.broadcast"([[VAL_102]]) {broadcast_sizes = dense<[500, 100, 75]> : tensor<3xi64>} : (tensor) -> tensor<500x100x75xf32> -// CHECK: [[VAL_104:%.*]] = "xla_hlo.slice"([[VAL_96]]) {limit_indices = dense<[500, 100, 1]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<500x100x75xf32>) -> tensor<500x100x1xf32> -// CHECK: [[VAL_105:%.*]] = "xla_hlo.slice"([[VAL_97]]) {limit_indices = dense<[500, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<500x75xf32>) -> tensor<500x1xf32> -// CHECK: [[VAL_106:%.*]] = "xla_hlo.negate"([[VAL_105]]) : (tensor<500x1xf32>) -> tensor<500x1xf32> -// CHECK: [[VAL_107:%.*]] = "xla_hlo.multiply"([[VAL_106]], [[VAL_104]]) {broadcast_dimensions = dense<[0, 2]> : tensor<2xi64>} : (tensor<500x1xf32>, tensor<500x100x1xf32>) -> tensor<500x100x1xf32> -// CHECK: [[VAL_108:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_109:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_110:%.*]] = "xla_hlo.dynamic-update-slice"([[VAL_103]], [[VAL_107]], [[VAL_109]], [[VAL_109]], [[VAL_108]]) : (tensor<500x100x75xf32>, tensor<500x100x1xf32>, tensor, tensor, tensor) -> tensor<500x100x75xf32> -// CHECK: [[VAL_111:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_112:%.*]] = "xla_hlo.tuple"([[VAL_111]], [[VAL_110]], [[VAL_96]], [[VAL_97]]) : (tensor, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>) -> tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>> -// CHECK: [[VAL_113:%.*]] = "xla_hlo.while"([[VAL_112]]) ( { -// CHECK: ^bb0([[VAL_114:%.*]]: tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>): -// CHECK: [[VAL_115:%.*]] = "xla_hlo.get_tuple_element"([[VAL_114]]) {index = 0 : i32} : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor -// CHECK: [[VAL_116:%.*]] = xla_hlo.constant dense<74> : tensor -// CHECK: [[VAL_117:%.*]] = "xla_hlo.compare"([[VAL_115]], [[VAL_116]]) {comparison_direction = "LT"} : (tensor, tensor) -> tensor -// CHECK: 
"xla_hlo.return"([[VAL_117]]) : (tensor) -> () -// CHECK: }, { -// CHECK: ^bb0([[VAL_118:%.*]]: tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>): -// CHECK: [[VAL_119:%.*]] = "xla_hlo.get_tuple_element"([[VAL_118]]) {index = 0 : i32} : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor -// CHECK: [[VAL_120:%.*]] = "xla_hlo.get_tuple_element"([[VAL_118]]) {index = 1 : i32} : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_121:%.*]] = "xla_hlo.get_tuple_element"([[VAL_118]]) {index = 2 : i32} : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_122:%.*]] = "xla_hlo.get_tuple_element"([[VAL_118]]) {index = 3 : i32} : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x75xf32> -// CHECK: [[VAL_123:%.*]] = xla_hlo.constant dense<1> : tensor -// CHECK: [[VAL_124:%.*]] = xla_hlo.add [[VAL_119]], [[VAL_123]] : tensor -// CHECK: [[VAL_125:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_126:%.*]] = "xla_hlo.dynamic-slice"([[VAL_121]], [[VAL_125]], [[VAL_125]], [[VAL_124]]) {slice_sizes = dense<[500, 100, 1]> : tensor<3xi64>} : (tensor<500x100x75xf32>, tensor, tensor, tensor) -> tensor<500x100x1xf32> -// CHECK: [[VAL_127:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_128:%.*]] = "xla_hlo.dynamic-slice"([[VAL_122]], [[VAL_127]], [[VAL_124]]) {slice_sizes = dense<[500, 1]> : tensor<2xi64>} : (tensor<500x75xf32>, tensor, tensor) -> tensor<500x1xf32> -// CHECK: [[VAL_129:%.*]] = "xla_hlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<500x100x75xi32> -// CHECK: [[VAL_130:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor -// CHECK: [[VAL_131:%.*]] = "xla_hlo.broadcast"([[VAL_130]]) {broadcast_sizes = dense<[500, 100, 75]> : tensor<3xi64>} : (tensor) -> tensor<500x100x75xf32> -// CHECK: [[VAL_132:%.*]] = "xla_hlo.compare"([[VAL_129]], [[VAL_124]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GE"} : (tensor<500x100x75xi32>, tensor) -> tensor<500x100x75xi1> -// CHECK: [[VAL_133:%.*]] = "xla_hlo.select"([[VAL_132]], [[VAL_131]], [[VAL_121]]) : (tensor<500x100x75xi1>, tensor<500x100x75xf32>, tensor<500x100x75xf32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_134:%.*]] = "xla_hlo.dot_general"([[VAL_133]], [[VAL_126]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<1> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["HIGHEST", "HIGHEST"]} : (tensor<500x100x75xf32>, tensor<500x100x1xf32>) -> tensor<500x75x1xf32> -// CHECK: [[VAL_135:%.*]] = "xla_hlo.dot_general"([[VAL_120]], [[VAL_134]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["HIGHEST", "HIGHEST"]} : (tensor<500x100x75xf32>, tensor<500x75x1xf32>) -> tensor<500x100x1xf32> -// CHECK: [[VAL_136:%.*]] = "xla_hlo.negate"([[VAL_128]]) : (tensor<500x1xf32>) -> tensor<500x1xf32> -// CHECK: [[VAL_137:%.*]] = xla_hlo.add [[VAL_126]], [[VAL_135]] : tensor<500x100x1xf32> -// CHECK: [[VAL_138:%.*]] = "xla_hlo.multiply"([[VAL_136]], [[VAL_137]]) {broadcast_dimensions = dense<[0, 2]> : tensor<2xi64>} : 
(tensor<500x1xf32>, tensor<500x100x1xf32>) -> tensor<500x100x1xf32> -// CHECK: [[VAL_139:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_140:%.*]] = "xla_hlo.dynamic-update-slice"([[VAL_120]], [[VAL_138]], [[VAL_139]], [[VAL_139]], [[VAL_124]]) : (tensor<500x100x75xf32>, tensor<500x100x1xf32>, tensor, tensor, tensor) -> tensor<500x100x75xf32> -// CHECK: [[VAL_141:%.*]] = xla_hlo.constant dense<1> : tensor -// CHECK: [[VAL_142:%.*]] = xla_hlo.add [[VAL_119]], [[VAL_141]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor -// CHECK: [[VAL_143:%.*]] = "xla_hlo.tuple"([[VAL_142]], [[VAL_140]], [[VAL_121]], [[VAL_122]]) : (tensor, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>) -> tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>> -// CHECK: "xla_hlo.return"([[VAL_143]]) : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> () -// CHECK: }) : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>> -// CHECK: [[VAL_144:%.*]] = "xla_hlo.get_tuple_element"([[VAL_145:%.*]]) {index = 1 : i32} : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_146:%.*]] = "xla_hlo.get_tuple_element"([[VAL_145]]) {index = 2 : i32} : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_147:%.*]] = "xla_hlo.get_tuple_element"([[VAL_145]]) {index = 3 : i32} : (tuple, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x75xf32> -// CHECK: [[VAL_148:%.*]] = "xla_hlo.slice"([[VAL_101]]) {limit_indices = dense<[500, 100, 75]> : tensor<3xi64>, start_indices = dense<[0, 0, 75]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<500x100x75xf32>) -> tensor<500x100x0xf32> -// CHECK: [[VAL_149:%.*]] = "xla_hlo.dot_general"([[VAL_144]], [[VAL_148]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<1> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["HIGHEST", "HIGHEST"]} : (tensor<500x100x75xf32>, tensor<500x100x0xf32>) -> tensor<500x75x0xf32> -// CHECK: [[VAL_150:%.*]] = "xla_hlo.dot_general"([[VAL_96]], [[VAL_149]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["HIGHEST", "HIGHEST"]} : (tensor<500x100x75xf32>, tensor<500x75x0xf32>) -> tensor<500x100x0xf32> -// CHECK: [[VAL_151:%.*]] = xla_hlo.add [[VAL_148]], [[VAL_150]] : tensor<500x100x0xf32> -// CHECK: [[VAL_152:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_153:%.*]] = xla_hlo.constant dense<75> : tensor -// CHECK: [[VAL_154:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_155:%.*]] = "xla_hlo.dynamic-update-slice"([[VAL_101]], [[VAL_151]], [[VAL_154]], [[VAL_152]], [[VAL_153]]) : (tensor<500x100x75xf32>, tensor<500x100x0xf32>, tensor, tensor, tensor) -> tensor<500x100x75xf32> -// CHECK: [[VAL_156:%.*]] = "xla_hlo.slice"([[VAL_5]]) {limit_indices = dense<[500, 100, 100]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<500x100x100xf32>) -> tensor<500x100x100xf32> 
-// CHECK: [[VAL_157:%.*]] = "xla_hlo.dot_general"([[VAL_156]], [[VAL_144]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["HIGHEST", "HIGHEST"]} : (tensor<500x100x100xf32>, tensor<500x100x75xf32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_158:%.*]] = "xla_hlo.dot_general"([[VAL_157]], [[VAL_96]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["HIGHEST", "HIGHEST"]} : (tensor<500x100x75xf32>, tensor<500x100x75xf32>) -> tensor<500x100x100xf32> -// CHECK: [[VAL_159:%.*]] = xla_hlo.add [[VAL_156]], [[VAL_158]] : tensor<500x100x100xf32> -// CHECK: [[VAL_160:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_161:%.*]] = xla_hlo.constant dense<0> : tensor -// CHECK: [[VAL_162:%.*]] = "xla_hlo.dynamic-update-slice"([[VAL_5]], [[VAL_159]], [[VAL_161]], [[VAL_161]], [[VAL_160]]) : (tensor<500x100x100xf32>, tensor<500x100x100xf32>, tensor, tensor, tensor) -> tensor<500x100x100xf32> -// CHECK: [[VAL_163:%.*]] = "xla_hlo.slice"([[VAL_162]]) {limit_indices = dense<[500, 100, 75]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<500x100x100xf32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_164:%.*]] = "xla_hlo.slice"([[VAL_155]]) {limit_indices = dense<[500, 75, 75]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<500x100x75xf32>) -> tensor<500x75x75xf32> -// CHECK: return [[VAL_163]], [[VAL_164]] : tensor<500x100x75xf32>, tensor<500x75x75xf32> + // The tf.Qr lowering is a full algorithm that is not practical to verify with + // FileCheck. Just verify that it was converted. + // TODO(laurenzo): Move this out of the mainline tf2xla conversion as it is + // really only applicable to certain legacy uses. + // CHECK-NOT: "tf.Qr" %0:2 = "tf.Qr"(%arg0) {full_matrices = false} : (tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) return %0#0, %0#1 : tensor<500x100x75xf32>, tensor<500x75x75xf32> } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir index d25a84d0e25..9f27a204baf 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir @@ -1,4 +1,4 @@ -// RUN: xla-opt -xla-legalize-to-std %s -o - | FileCheck %s +// RUN: xla-opt -xla-legalize-to-std %s -o - | FileCheck %s --dump-input-on-failure // CHECK-LABEL: func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { @@ -42,40 +42,6 @@ func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32 return %4 : tensor<4xi32> } -// Broadcasting is not currently supported. -// TODO(suderman):Future pass should take all broadcasted binary ops and convert -// them to separate broadcast and binary op.
-// CHECK-LABEL: func @binary_ops_broadcast(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>) -> tensor<4x4xf32> { -func @binary_ops_broadcast(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>) -> tensor<4x4xf32> { - // CHECK-NEXT: %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "add.3"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) { - name = "add.3", broadcast_dimensions = dense<1> : tensor<1xi64>} : - (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - - // CHECK-NEXT: %1 = "xla_hlo.multiply"(%0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "mul.4"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - %1 = "xla_hlo.multiply"(%0, %arg1) { - name = "mul.4", broadcast_dimensions = dense<1> : tensor<1xi64>} : - (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - - // CHECK-NEXT: %2 = "xla_hlo.subtract"(%1, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "sub.5"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - %2 = "xla_hlo.subtract"(%1, %arg1) { - name = "sub.5", broadcast_dimensions = dense<1> : tensor<1xi64>} : - (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - - // CHECK-NEXT: %3 = "xla_hlo.divide"(%2, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "div.6"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - %3 = "xla_hlo.divide"(%2, %arg1) { - name = "div.6", broadcast_dimensions = dense<1> : tensor<1xi64>} : - (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - - // CHECK-NEXT: %4 = "xla_hlo.remainder"(%3, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - %4 = "xla_hlo.remainder"(%3, %arg1) { - broadcast_dimensions = dense<1> : tensor<1xi64>} : - (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - - // CHECK-NEXT: return %4 : tensor<4x4xf32> - return %4 : tensor<4x4xf32> -} - // CHECK-LABEL: func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) { func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) { // CHECK-NEXT: %0 = cmpi "eq", %arg0, %arg0 : tensor<4xi32> diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir index 013748fea28..99b1766e73c 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir @@ -24,9 +24,9 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>, // CHECK-LABEL: func @fusion // CHECK: %[[C1:.*]] = constant 1 // CHECK-NOT: linalg.generic -// CHECK: loop.for {{.*}} step %[[C1]] -// CHECK: loop.for {{.*}} step %[[C1]] -// CHECK-NOT: loop.for +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK-NOT: scf.for // CHECK: linalg.generic // CHECK: addf // CHECK: linalg.generic @@ -36,9 +36,9 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>, // TILED-DAG: %[[C2:.*]] = constant 2 // TILED-DAG: %[[C3:.*]] = constant 3 // TILED-NOT: linalg.generic -// TILED: loop.for {{.*}} step %[[C2]] -// TILED: loop.for {{.*}} step %[[C3]] -// TILED-NOT: loop.for +// TILED: scf.for {{.*}} step %[[C2]] +// TILED: scf.for {{.*}} step %[[C3]] +// TILED-NOT: scf.for // TILED: linalg.generic // TILED: addf // TILED: linalg.generic @@ -46,8 +46,8 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: 
memref<6x6xf32>, // PLOOP-LABEL: func @fusion // PLOOP-NOT: linalg.generic -// PLOOP: loop.parallel -// PLOOP-NOT: loop.parallel +// PLOOP: scf.parallel +// PLOOP-NOT: scf.parallel // PLOOP: linalg.generic // PLOOP: addf // PLOOP: linalg.generic @@ -94,9 +94,9 @@ func @fusion_of_three(%arg0: memref<100x10xf32>, // CHECK-LABEL: func @fusion // CHECK: %[[C1:.*]] = constant 1 // CHECK-NOT: linalg.generic -// CHECK: loop.for {{.*}} step %[[C1]] -// CHECK: loop.for {{.*}} step %[[C1]] -// CHECK-NOT: loop.for +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK-NOT: scf.for // CHECK: linalg.generic // CHECK: linalg.generic // CHECK: subf @@ -107,9 +107,9 @@ func @fusion_of_three(%arg0: memref<100x10xf32>, // TILED-DAG: %[[C2:.*]] = constant 2 // TILED-DAG: %[[C3:.*]] = constant 3 // TILED-NOT: linalg.generic -// TILED: loop.for {{.*}} step %[[C2]] -// TILED: loop.for {{.*}} step %[[C3]] -// TILED-NOT: loop.for +// TILED: scf.for {{.*}} step %[[C2]] +// TILED: scf.for {{.*}} step %[[C3]] +// TILED-NOT: scf.for // TILED: linalg.generic // TILED: linalg.generic // TILED: subf @@ -118,8 +118,8 @@ func @fusion_of_three(%arg0: memref<100x10xf32>, // PLOOP-LABEL: func @fusion_of_three // PLOOP-NOT: linalg.generic -// PLOOP: loop.parallel -// PLOOP-NOT: loop.parallel +// PLOOP: scf.parallel +// PLOOP-NOT: scf.parallel // PLOOP: linalg.generic // PLOOP: linalg.generic // PLOOP: subf @@ -147,11 +147,11 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32 // CHECK-LABEL: func @fusion_4d // CHECK: %[[C1:.*]] = constant 1 // CHECK-NOT: linalg.generic -// CHECK: loop.for {{.*}} step %[[C1]] -// CHECK: loop.for {{.*}} step %[[C1]] -// CHECK: loop.for {{.*}} step %[[C1]] -// CHECK: loop.for {{.*}} step %[[C1]] -// CHECK-NOT: loop.for +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK-NOT: scf.for // CHECK: linalg.generic // CHECK: addf // CHECK: linalg.generic @@ -161,9 +161,9 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32 // TILED-DAG: %[[C2:.*]] = constant 2 // TILED-DAG: %[[C3:.*]] = constant 3 // TILED-NOT: linalg.generic -// TILED: loop.for {{.*}} step %[[C2]] -// TILED: loop.for {{.*}} step %[[C3]] -// TILED-NOT: loop.for +// TILED: scf.for {{.*}} step %[[C2]] +// TILED: scf.for {{.*}} step %[[C3]] +// TILED-NOT: scf.for // TILED: linalg.generic // TILED: addf // TILED: linalg.generic @@ -171,8 +171,8 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32 // PLOOP-LABEL: func @fusion_4d // PLOOP-NOT: linalg.generic -// PLOOP: loop.parallel -// PLOOP-NOT: loop.parallel +// PLOOP: scf.parallel +// PLOOP-NOT: scf.parallel // PLOOP: linalg.generic // PLOOP: addf // PLOOP: linalg.generic diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-select-and-scatter.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-select-and-scatter.mlir index 5b763cde2ed..c640b395f4d 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-select-and-scatter.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-select-and-scatter.mlir @@ -50,19 +50,19 @@ func @select_and_scatter(%arg: memref<112x112xf32>, // Parallel loop to initialize the output buffer. 
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]][] : memref -// CHECK: loop.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) +// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) // CHECK-SAME: to ([[C112]], [[C112]]) step ([[C1]], [[C1]]) { // CHECK: store [[INIT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]] -// CHECK: loop.yield +// CHECK: scf.yield // CHECK: } // Parallel loop over source buffer to compute scattered values. -// CHECK: loop.parallel ([[II:%.*]], [[JJ:%.*]]) = ([[C0]], [[C0]]) +// CHECK: scf.parallel ([[II:%.*]], [[JJ:%.*]]) = ([[C0]], [[C0]]) // CHECK-SAME: to ([[C56]], [[C56]]) step ([[C1]], [[C1]]) { // Window loop w.r.t. first dim. // CHECK: [[SEL_RES_I:%.*]]:4 -// CHECK-SAME: = loop.for [[WIN_I:%.*]] = [[C0]] to [[C3]] step [[C1]] +// CHECK-SAME: = scf.for [[WIN_I:%.*]] = [[C0]] to [[C3]] step [[C1]] // CHECK-SAME: iter_args( // CHECK-SAME: [[SEL_I_0:%.*]] = [[C0]], [[SEL_J_0:%.*]] = [[C0]], // CHECK-SAME: [[SEL_VAL_0:%.*]] = [[C0_F32]], @@ -71,7 +71,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>, // Window loop w.r.t. second dim. // CHECK: [[SEL_RES_J:%.*]]:4 -// CHECK-SAME: = loop.for [[WIN_J:%.*]] = [[C0]] to [[C3]] step [[C1]] +// CHECK-SAME: = scf.for [[WIN_J:%.*]] = [[C0]] to [[C3]] step [[C1]] // CHECK-SAME: iter_args( // CHECK-SAME: [[SEL_I:%.*]] = [[SEL_I_0]], [[SEL_J:%.*]] = [[SEL_J_0]], // CHECK-SAME: [[SEL_VAL:%.*]] = [[SEL_VAL_0]], @@ -102,14 +102,14 @@ func @select_and_scatter(%arg: memref<112x112xf32>, // be applied, current selected ivs (SEL_I, SEL_J) and value (SEL_VAL) are // returned in that case. // CHECK: [[IF_INBOUNDS_RES:%.*]]:4 -// CHECK-SAME: = loop.if [[INBOUNDS_1]] -> (index, index, f32, i1) { +// CHECK-SAME: = scf.if [[INBOUNDS_1]] -> (index, index, f32, i1) { // INBOUNDS-THEN-BODY, i.e. if INBOUNDS == true // CHECK: [[ARG_ELEM:%.*]] = load [[ARG_BUF]]{{\[}}[[ARG_I]], [[ARG_J]]] // CHECK: [[IF_INIT_RES:%.*]]:4 - // CHECK-SAME: = loop.if [[SEL_INIT]] -> (index, index, f32, i1) { + // CHECK-SAME: = scf.if [[SEL_INIT]] -> (index, index, f32, i1) { // INIT-THEN-BODY, i.e. INBOUNDS == true and INIT = true @@ -133,40 +133,40 @@ func @select_and_scatter(%arg: memref<112x112xf32>, // Depending on PRED, return ARG ivs & elem or current select ivs and value. - // CHECK: [[IF_PRED_RES:%.*]]:4 = loop.if [[PRED]] - // CHECK: loop.yield [[ARG_I]], [[ARG_J]], [[ARG_ELEM]], [[CTRUE]] + // CHECK: [[IF_PRED_RES:%.*]]:4 = scf.if [[PRED]] + // CHECK: scf.yield [[ARG_I]], [[ARG_J]], [[ARG_ELEM]], [[CTRUE]] // CHECK: } else { - // CHECK: loop.yield [[SEL_I]], [[SEL_J]], [[SEL_VAL]], [[SEL_INIT]] + // CHECK: scf.yield [[SEL_I]], [[SEL_J]], [[SEL_VAL]], [[SEL_INIT]] // CHECK: } // INIT-THEN-BODY yield. - // CHECK: loop.yield [[IF_PRED_RES]]#0, [[IF_PRED_RES]]#1, + // CHECK: scf.yield [[IF_PRED_RES]]#0, [[IF_PRED_RES]]#1, // CHECK-SAME: [[IF_PRED_RES]]#2, [[IF_PRED_RES]]#3 // INIT-ELSE-BODY, i.e. if INBOUNDS == TRUE and INIT == FALSE, returns ARG // ivs and element without computing Select function. - // CHECK: loop.yield [[ARG_I]], [[ARG_J]], [[ARG_ELEM]], + // CHECK: scf.yield [[ARG_I]], [[ARG_J]], [[ARG_ELEM]], // CHECK-SAME: [[CTRUE]] : index, index, f32, i1 // CHECK: } // INBOUNDS-THEN-BODY yield. - // CHECK: loop.yield [[IF_INIT_RES]]#0, [[IF_INIT_RES]]#1, [[IF_INIT_RES]]#2, + // CHECK: scf.yield [[IF_INIT_RES]]#0, [[IF_INIT_RES]]#1, [[IF_INIT_RES]]#2, // CHECK-SAME: [[IF_INIT_RES]]#3 : index, index, f32, i1 // CHECK: } // INBOUNDS-ELSE-REGION, i.e. if INBOUNDS == FALSE // We are in the pad area, return current iter_args. 
- // CHECK: loop.yield [[SEL_I]], [[SEL_J]], [[SEL_VAL]], + // CHECK: scf.yield [[SEL_I]], [[SEL_J]], [[SEL_VAL]], // CHECK-SAME: [[SEL_INIT]] : index, index, f32, i1 // CHECK: } // Window loop w.r.t. second dim yield. -// CHECK: loop.yield [[IF_INBOUNDS_RES]]#0, [[IF_INBOUNDS_RES]]#1, +// CHECK: scf.yield [[IF_INBOUNDS_RES]]#0, [[IF_INBOUNDS_RES]]#1, // CHECK-SAME: [[IF_INBOUNDS_RES]]#2, [[IF_INBOUNDS_RES]]#3 // CHECK: } // Window loop w.r.t. first dim yield. -// CHECK: loop.yield [[SEL_RES_J]]#0, [[SEL_RES_J]]#1, [[SEL_RES_J]]#2, +// CHECK: scf.yield [[SEL_RES_J]]#0, [[SEL_RES_J]]#1, [[SEL_RES_J]]#2, // CHECK-SAME: [[SEL_RES_J]]#3 : index, index, f32, i1 // CHECK: } @@ -196,4 +196,4 @@ func @select_and_scatter(%arg: memref<112x112xf32>, // CHECK: atomic_yield [[RES]] : f32 // Parallel loop over source buffer yield -// CHECK: loop.yield +// CHECK: scf.yield diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-affine.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-affine.mlir index 08ba9f02f3e..aaf65b5a38a 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-affine.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-affine.mlir @@ -5,15 +5,15 @@ func @min_op(%lhs: memref<4x3x2x1xf32>, %rhs: memref<4x3x2x1xf32>, %result: memref<4x3x2x1xf32>) -> () { // CHECK-NEXT: affine.for %[[I:.*]] = 0 to 4 { - // CHECK-NEXT: affine.for %[[J:.*]] = 0 to 3 { - // CHECK-NEXT: affine.for %[[K:.*]] = 0 to 2 { - // CHECK-NEXT: affine.for %[[L:.*]] = 0 to 1 { - // CHECK-NEXT: %[[LHS:.*]] = load %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32> - // CHECK-NEXT: %[[RHS:.*]] = load %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32> - // CHECK-NEXT: %[[MIN_PREDICATE:.*]] = cmpf "olt", %[[LHS]], %[[RHS]] : f32 - // CHECK-NEXT: %[[MIN:.*]] = select %[[MIN_PREDICATE]], %[[LHS]], %[[RHS]] : f32 - // CHECK-NEXT: store %[[MIN]], %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32> - // CHECK: return + // CHECK-NEXT: affine.for %[[J:.*]] = 0 to 3 { + // CHECK-NEXT: affine.for %[[K:.*]] = 0 to 2 { + // CHECK-NEXT: affine.for %[[L:.*]] = 0 to 1 { + // CHECK-NEXT: %[[LHS:.*]] = affine.load %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32> + // CHECK-NEXT: %[[RHS:.*]] = affine.load %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32> + // CHECK-NEXT: %[[MIN_PREDICATE:.*]] = cmpf "olt", %[[LHS]], %[[RHS]] : f32 + // CHECK-NEXT: %[[MIN:.*]] = select %[[MIN_PREDICATE]], %[[LHS]], %[[RHS]] : f32 + // CHECK-NEXT: affine.store %[[MIN]], %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32> + // CHECK: return "xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} : (memref<4x3x2x1xf32>, memref<4x3x2x1xf32>, memref<4x3x2x1xf32>) -> () return diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-gpu.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-gpu.mlir index 4d878cee6f4..16ffbf241b0 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-gpu.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-gpu.mlir @@ -22,7 +22,7 @@ func @reduce(%arg: memref<100x10xf32>, // CHECK-DAG: %[[LB:.*]] = constant 0 : index // CHECK-DAG: %[[UB:.*]] = constant 10 : index // CHECK-DAG: %[[STEP:.*]] = constant 1 : index -// CHECK: loop.for %[[IDX1:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { +// CHECK: scf.for %[[IDX1:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { // CHECK: %[[LHS:.*]] = linalg.slice %[[ARG2]][%[[IDX]]] : memref<100xf32>, index, memref // CHECK: %[[RHS:.*]] = linalg.slice %[[ARG0]][%[[IDX]], 
%[[IDX1]]] : memref<100x10xf32>, index, index, memref<f32> // CHECK: "xla_lhlo.add"(%[[LHS]], %[[RHS]], %[[LHS]]) : (memref<f32>, memref<f32>, memref<f32>) -> () diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir index 0fc30ed4901..626e905695c 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir @@ -411,6 +411,19 @@ func @convert_f32_to_f32(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // ----- +// CHECK-LABEL: func @convert_f32_to_i32 +func @convert_f32_to_i32(%input: memref<2x2xf32>, %result: memref<2x2xi32>) { + "xla_lhlo.convert"(%input, %result) + : (memref<2x2xf32>, memref<2x2xi32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]: i32): +// CHECK-NEXT: %[[RESULT:.*]] = fptosi %[[OPERAND_IN]] : f32 to i32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 + +// ----- + // CHECK-LABEL: func @cos func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { "xla_lhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () @@ -523,6 +536,48 @@ func @tanh(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-NEXT: %[[RESULT:.*]] = tanh %[[OPERAND_IN]] : f32 // CHECK-NEXT: linalg.yield %[[RESULT]] : f32 +// ----- + +// CHECK-LABEL: func @complex +func @complex(%real: memref<2x2xf32>, + %imag: memref<2x2xf32>, + %cplx: memref<2x2xcomplex<f32>>) { + "xla_lhlo.complex"(%real, %imag, %cplx) + : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xcomplex<f32>>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[RE:.*]]: f32, %[[IM:.*]]: f32, %[[CP:.*]]: complex<f32>): +// CHECK-NEXT: %[[RESULT:.*]] = create_complex %[[RE]], %[[IM]] : complex<f32> +// CHECK-NEXT: linalg.yield %[[RESULT]] : complex<f32> + +// ----- + +// CHECK-LABEL: func @real +func @real(%cplx: memref<2x2xcomplex<f32>>, + %real: memref<2x2xf32>) { + "xla_lhlo.real"(%cplx, %real) + : (memref<2x2xcomplex<f32>>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[CPLX_IN:.*]]: complex<f32>, %[[REAL_OUT:.*]]: f32): +// CHECK-NEXT: %[[REAL:.*]] = re %[[CPLX_IN:.*]] : complex<f32> +// CHECK-NEXT: linalg.yield %[[REAL]] : f32 + +// ----- + +// CHECK-LABEL: func @imag +func @imag(%cplx: memref<2x2xcomplex<f32>>, + %imag: memref<2x2xf32>) { + "xla_lhlo.imag"(%cplx, %imag) + : (memref<2x2xcomplex<f32>>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[CPLX_IN:.*]]: complex<f32>, %[[IMAG_OUT:.*]]: f32): +// CHECK-NEXT: %[[IMAG:.*]] = im %[[CPLX_IN:.*]] : complex<f32> +// CHECK-NEXT: linalg.yield %[[IMAG]] : f32 // ----- @@ -581,3 +636,16 @@ func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) { return } // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] + +// ----- + +// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)> +// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @reverse +func @reverse(%arg0: memref<2x3xf32>, %arg1: memref<2x3xf32>) { + "xla_lhlo.reverse"(%arg0, %arg1) { + dimensions = dense<1> : tensor<1xi64> + } : (memref<2x3xf32>, memref<2x3xf32>) -> () + return +} +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-llvm.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-llvm.mlir new file mode 100644 index 00000000000..16aad8f7cf3 --- /dev/null
+++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-llvm.mlir @@ -0,0 +1,65 @@ +// RUN: xla-opt %s --test-lhlo-legalize-to-llvm -split-input-file | FileCheck %s --dump-input-on-failure + +// CHECK-LABEL: func @static_memref_cast +func @static_memref_cast(%buf : memref<10x1x5xf32>) { + %0 = xla_lhlo.static_memref_cast %buf + : memref<10x1x5xf32> -> memref<10x5xf32, offset: 2, strides: [5, 1]> + return +} +// CHECK: %[[INPUT_MEMREF_BLDR:.*]] = llvm.mlir.undef : [[DESCRIPTOR_TYPE_3D:!.*]] +// CHECK: llvm.insertvalue +// CHECK: %[[MEMREF_BLDR_0:.*]] = llvm.mlir.undef : [[DESCRIPTOR_TYPE_2D:!.*]] + +// CHECK: %[[IN_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF:.*]][0] : [[DESCRIPTOR_TYPE_3D]] +// CHECK: %[[PTR:.*]] = llvm.bitcast %[[IN_PTR]] : !llvm<"float*"> to !llvm<"float*"> +// CHECK: %[[MEMREF_BLDR_1:.*]] = llvm.insertvalue %[[PTR]], %[[MEMREF_BLDR_0]][0] : [[DESCRIPTOR_TYPE_2D]] + +// CHECK: %[[IN_ALIGNED_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][1] : [[DESCRIPTOR_TYPE_3D]] +// CHECK: %[[ALIGNED_PTR:.*]] = llvm.bitcast %[[IN_ALIGNED_PTR]] : !llvm<"float*"> to !llvm<"float*"> +// CHECK: %[[MEMREF_BLDR_2:.*]] = llvm.insertvalue %[[ALIGNED_PTR]], %[[MEMREF_BLDR_1]][1] : [[DESCRIPTOR_TYPE_2D]] + +// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64 +// CHECK: %[[MEMREF_BLDR_3:.*]] = llvm.insertvalue %[[C2]], %[[MEMREF_BLDR_2]][2] : [[DESCRIPTOR_TYPE_2D]] + +// CHECK: %[[C10:.*]] = llvm.mlir.constant(10 : index) : !llvm.i64 +// CHECK: %[[MEMREF_BLDR_4:.*]] = llvm.insertvalue %[[C10]], %[[MEMREF_BLDR_3]][3, 0] : [[DESCRIPTOR_TYPE_2D]] +// CHECK: %[[C5:.*]] = llvm.mlir.constant(5 : index) : !llvm.i64 +// CHECK: %[[MEMREF_BLDR_5:.*]] = llvm.insertvalue %[[C5]], %[[MEMREF_BLDR_4]][4, 0] : [[DESCRIPTOR_TYPE_2D]] +// CHECK: %[[C5_:.*]] = llvm.mlir.constant(5 : index) : !llvm.i64 +// CHECK: %[[MEMREF_BLDR_6:.*]] = llvm.insertvalue %[[C5_]], %[[MEMREF_BLDR_5]][3, 1] : [[DESCRIPTOR_TYPE_2D]] +// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// CHECK: %[[MEMREF_BLDR_7:.*]] = llvm.insertvalue %[[C1]], %[[MEMREF_BLDR_6]][4, 1] : [[DESCRIPTOR_TYPE_2D]] + +// ----- + +// CHECK-LABEL: func @dynamic_memref_cast +func @dynamic_memref_cast(%buf : memref) { + %size_X = constant 10 : index + %size_Y = constant 50 : index + %stride_X = constant 1 : index + %stride_Y = constant 0 : index + %0 = xla_lhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%stride_X, %stride_Y] + : memref -> memref + return +} +// CHECK: %[[C10:.*]] = llvm.mlir.constant(10 : index) : !llvm.i64 +// CHECK: %[[C50:.*]] = llvm.mlir.constant(50 : index) : !llvm.i64 +// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 + +// CHECK: %[[MEMREF_BLDR_0:.*]] = llvm.mlir.undef : [[DESCRIPTOR_TYPE:!.*]] + +// CHECK: %[[IN_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF:.*]][0] : [[DESCRIPTOR_TYPE]] +// CHECK: %[[PTR:.*]] = llvm.bitcast %[[IN_PTR]] : !llvm<"float*"> to !llvm<"float*"> +// CHECK: %[[MEMREF_BLDR_1:.*]] = llvm.insertvalue %[[PTR]], %[[MEMREF_BLDR_0]][0] : [[DESCRIPTOR_TYPE]] + +// CHECK: %[[IN_ALIGNED_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][1] : [[DESCRIPTOR_TYPE]] +// CHECK: %[[ALIGNED_PTR:.*]] = llvm.bitcast %[[IN_ALIGNED_PTR]] : !llvm<"float*"> to !llvm<"float*"> +// CHECK: %[[MEMREF_BLDR_2:.*]] = llvm.insertvalue %[[ALIGNED_PTR]], %[[MEMREF_BLDR_1]][1] : [[DESCRIPTOR_TYPE]] + +// CHECK: %[[SRC_OFFSET:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][2] : [[DESCRIPTOR_TYPE]] +// CHECK: %[[MEMREF_BLDR_3:.*]] = llvm.insertvalue 
%[[SRC_OFFSET]], %[[MEMREF_BLDR_2]][2] : [[DESCRIPTOR_TYPE]] +// CHECK: %[[MEMREF_BLDR_4:.*]] = llvm.insertvalue %[[C10]], %[[MEMREF_BLDR_3]][3, 0] : [[DESCRIPTOR_TYPE]] +// CHECK: %[[MEMREF_BLDR_5:.*]] = llvm.insertvalue %[[C1]], %[[MEMREF_BLDR_4]][4, 0] : [[DESCRIPTOR_TYPE]] +// CHECK: %[[MEMREF_BLDR_6:.*]] = llvm.insertvalue %[[C50]], %[[MEMREF_BLDR_5]][3, 1] : [[DESCRIPTOR_TYPE]] +// CHECK: %[[MEMREF_BLDR_7:.*]] = llvm.insertvalue %[[C0]], %[[MEMREF_BLDR_6]][4, 1] : [[DESCRIPTOR_TYPE]] diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir index cb169e060ef..32c367f97d6 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir @@ -22,13 +22,13 @@ func @reduce(%arg: memref<100x10x5xf32>, // CHECK-DAG: [[C10:%.*]] = constant 10 : index // CHECK-DAG: [[C100:%.*]] = constant 100 : index // CHECK: [[INIT:%.*]] = load [[INIT_BUF]] -// CHECK: loop.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]]) +// CHECK: scf.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]]) // CHECK-SAME: to ([[C100]], [[C5]]) step ([[C1]], [[C1]]) { -// CHECK: [[REDUCTION_RESULT:%.*]] = loop.parallel ([[J:%.*]]) = +// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[J:%.*]]) = // CHECK-SAME: ([[C0]]) to ([[C10]]) step ([[C1]]) init ([[INIT]]) -> f32 { // CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]] // CHECK-SAME: {{\[}}[[I]], [[J]], [[K]]] : memref<100x10x5xf32> -// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 { +// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 { // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): // CHECK: [[ELEM_BUF:%.*]] = alloc() : memref // CHECK: [[ACC_BUF:%.*]] = alloc() : memref @@ -37,12 +37,12 @@ func @reduce(%arg: memref<100x10x5xf32>, // CHECK: store [[ACC]], [[ACC_BUF]][] : memref // CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref -// CHECK: loop.reduce.return [[ACC_RESULT]] : f32 +// CHECK: scf.reduce.return [[ACC_RESULT]] : f32 // CHECK: } -// CHECK: loop.yield +// CHECK: scf.yield // CHECK: } // CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]] -// CHECK: loop.yield +// CHECK: scf.yield // ----- @@ -66,10 +66,10 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>, // CHECK-DAG: [[C1:%.*]] = constant 1 : index // CHECK-DAG: [[C100:%.*]] = constant 100 : index // CHECK: [[INIT:%.*]] = load [[INIT_BUF]] -// CHECK: [[REDUCTION_RESULT:%.*]] = loop.parallel ([[I:%.*]]) = ([[C0]]) +// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[I:%.*]]) = ([[C0]]) // CHECK-SAME: to ([[C100]]) step ([[C1]]) init ([[INIT]]) -> f32 { // CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]]{{\[}}[[I]]{{\]}} -// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 { +// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 { // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): // CHECK: [[ELEM_BUF:%.*]] = alloc() : memref // CHECK: [[ACC_BUF:%.*]] = alloc() : memref @@ -78,9 +78,9 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>, // CHECK: store [[ACC]], [[ACC_BUF]][] : memref // CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref -// CHECK: loop.reduce.return [[ACC_RESULT]] +// CHECK: scf.reduce.return [[ACC_RESULT]] // CHECK: } -// CHECK: loop.yield +// CHECK: scf.yield // CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[C0]]] // ----- 
@@ -107,13 +107,13 @@ func @dynamic_reduce(%arg: memref, // CHECK: [[DIM1:%.*]] = dim [[ARG_BUF]], 1 : memref // CHECK: [[DIM2:%.*]] = dim [[ARG_BUF]], 2 : memref // CHECK: [[INIT:%.*]] = load [[INIT_BUF]] -// CHECK: loop.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]]) +// CHECK: scf.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]]) // CHECK-SAME: to ([[DIM0]], [[DIM2]]) step ([[C1]], [[C1]]) { -// CHECK: [[REDUCTION_RESULT:%.*]] = loop.parallel ([[J:%.*]]) = +// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[J:%.*]]) = // CHECK-SAME: ([[C0]]) to ([[DIM1]]) step ([[C1]]) init ([[INIT]]) -> f32 { // CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]] // CHECK-SAME: {{\[}}[[I]], [[J]], [[K]]] : memref -// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 { +// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 { // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): // CHECK: [[ELEM_BUF:%.*]] = alloc() : memref // CHECK: [[ACC_BUF:%.*]] = alloc() : memref @@ -122,12 +122,12 @@ func @dynamic_reduce(%arg: memref, // CHECK: store [[ACC]], [[ACC_BUF]][] : memref // CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref -// CHECK: loop.reduce.return [[ACC_RESULT]] : f32 +// CHECK: scf.reduce.return [[ACC_RESULT]] : f32 // CHECK: } -// CHECK: loop.yield +// CHECK: scf.yield // CHECK: } // CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]] -// CHECK: loop.yield +// CHECK: scf.yield // ----- @@ -158,9 +158,9 @@ func @reduce_window(%arg: memref<112x112xf32>, // CHECK-DAG: [[C56:%.*]] = constant 56 : index // CHECK-DAG: [[C112:%.*]] = constant 112 : index // CHECK: [[INIT:%.*]] = load [[INIT_BUF]][] : memref -// CHECK: loop.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) +// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) // CHECK-SAME: to ([[C56]], [[C56]]) step ([[C1]], [[C1]]) { -// CHECK: [[REDUCTION_RESULT:%.*]] = loop.parallel +// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel // CHECK-SAME: ([[IW:%.*]], [[JW:%.*]]) = ([[C0]], [[C0]]) // CHECK-SAME: to ([[C3]], [[C3]]) step ([[C1]], [[C1]]) // CHECK-SAME: init ([[INIT]]) -> f32 { @@ -177,15 +177,15 @@ func @reduce_window(%arg: memref<112x112xf32>, // CHECK: [[INDEX_J_FITS:%.*]] = cmpi "ult", [[INDEX_J]], [[C112]] // CHECK: [[IN_BOUNDS_1:%.*]] = and [[IN_BOUNDS_0]], [[INDEX_J_FITS]] -// CHECK: [[ELEM_TO_REDUCE:%.*]] = loop.if [[IN_BOUNDS_1]] -> (f32) { +// CHECK: [[ELEM_TO_REDUCE:%.*]] = scf.if [[IN_BOUNDS_1]] -> (f32) { // CHECK: [[OPERAND_ELEM:%.*]] = // CHECK-SAME: load [[OPERAND_BUF]]{{\[}}[[INDEX_I]], [[INDEX_J]]] -// CHECK: loop.yield [[OPERAND_ELEM]] : f32 +// CHECK: scf.yield [[OPERAND_ELEM]] : f32 // CHECK: } else { -// CHECK: loop.yield [[INIT]] : f32 +// CHECK: scf.yield [[INIT]] : f32 // CHECK: } -// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 { +// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 { // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): // CHECK: [[ELEM_BUF:%.*]] = alloc() : memref // CHECK: [[ACC_BUF:%.*]] = alloc() : memref @@ -194,12 +194,12 @@ func @reduce_window(%arg: memref<112x112xf32>, // CHECK: store [[ACC]], [[ACC_BUF]][] : memref // CHECK: "xla_lhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref -// CHECK: loop.reduce.return [[ACC_RESULT]] : f32 +// CHECK: scf.reduce.return [[ACC_RESULT]] : f32 // CHECK: } -// CHECK: loop.yield +// CHECK: scf.yield // CHECK: } // CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]] -// CHECK: 
loop.yield +// CHECK: scf.yield // CHECK: } // CHECK: return // CHECK: } diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir index 23e9d9b68e0..1a44428d2e9 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir @@ -178,3 +178,73 @@ func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: m } ) : () -> () return } + +// ----- + +// CHECK-LABEL: func @case_memref +func @case_memref(%index: memref, %operand_1: memref, %operand_2: memref, %operand_3: memref, %out: memref) -> () { + "xla_lhlo.case"(%index, %operand_1, %operand_2, %operand_3, %out) ( { + ^bb0(%arg0: memref): + "xla_lhlo.negate"(%arg0, %out) : (memref, memref) -> () + "xla_lhlo.terminator"() : () -> () + }, { + ^bb0(%arg0: memref): + "xla_lhlo.copy"(%arg0, %out) : (memref, memref) -> () + "xla_lhlo.terminator"() : () -> () + }, { + ^bb0(%arg0: memref): + "xla_lhlo.add"(%arg0, %arg0, %out) : (memref, memref, memref) -> () + "xla_lhlo.terminator"() : () -> () + } + ) : (memref, memref, memref, memref, memref) -> () + return +} + +// ----- + +func @static_memref_cast(%in: memref<10x1xf32>) { + %out = xla_lhlo.static_memref_cast %in + : memref<10x1xf32> -> memref<10xf32, offset: 0, strides: [1]> + return +} +// CHECK-LABEL: func @static_memref_cast + +// ----- + +func @static_memref_cast_dynamic_operand(%in: memref<10x?xf32>) { + // expected-error @+1 {{operand must have static shape}} + %out = xla_lhlo.static_memref_cast %in + : memref<10x?xf32> -> memref<10x1xf32, offset: 0, strides: [10, 1]> + return +} + +// ----- + +func @static_memref_cast_dynamic_result(%in: memref<10x1xf32>) { + // expected-error @+1 {{result must have static shape}} + %out = xla_lhlo.static_memref_cast %in + : memref<10x1xf32> -> memref<10x?xf32, offset: 0, strides: [?, ?]> + return +} + +// ----- + +func @dynamic_memref_cast(%in: memref) { + %size = constant 10 : index + %step = constant 1 : index + %out = xla_lhlo.dynamic_memref_cast %in(%size)[%step] + : memref -> memref + return +} +// CHECK-LABEL: func @dynamic_memref_cast + +// ----- + +func @dynamic_memref_cast_incompatible_result_type(%in: memref) { + // expected-error @+3 {{`sizes` args count must be equal to the rank of the output memref}} + %size = constant 10 : index + %step = constant 1 : index + %out = xla_lhlo.dynamic_memref_cast %in(%size)[%step] + : memref -> memref + return +} diff --git a/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir b/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir index 35a5ae549d5..81376761467 100644 --- a/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir @@ -1,4 +1,4 @@ -// RUN: xla-opt %s -test-xla-lower-complex | FileCheck %s +// RUN: xla-opt %s -test-xla-chlo-legalize-to-hlo -test-xla-lower-complex | FileCheck %s --dump-input-on-failure // CHECK-LABEL: @add func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { @@ -15,21 +15,6 @@ func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, % return %5, %6 : tensor<2xf32>, tensor<2xf32> } -// CHECK-LABEL: @add_broadcast -func @add_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) { - %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex>) - %3 = 
"xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - - // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.add"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.add"(%arg1, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} - %4 = "xla_hlo.add"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex>, tensor<2xcomplex>) -> (tensor<1x2xcomplex>) - %5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) - %6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) - - // CHECK: return [[VAL0]], [[VAL1]] - return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32> -} - // CHECK-LABEL: @add_unranked func @add_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) @@ -60,21 +45,6 @@ func @sub(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, % return %5, %6 : tensor<2xf32>, tensor<2xf32> } -// CHECK-LABEL: @sub_broadcast -func @sub_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) { - %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex>) - %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - - // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.subtract"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.subtract"(%arg1, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} - %4 = "xla_hlo.subtract"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex>, tensor<2xcomplex>) -> (tensor<1x2xcomplex>) - %5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) - %6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) - - // CHECK: return [[VAL0]], [[VAL1]] - return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32> -} - // CHECK-LABEL: @sub_unranked func @sub_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) @@ -109,25 +79,6 @@ func @mul(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, % return %5, %6 : tensor<2xf32>, tensor<2xf32> } -// CHECK-LABEL: @mul_broadcast -func @mul_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) { - %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex>) - %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - - // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.multiply"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.multiply"(%arg1, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.subtract [[VAL0]], [[VAL1]] - // CHECK-DAG: [[VAL3:%.+]] = "xla_hlo.multiply"(%arg0, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL4:%.+]] = "xla_hlo.multiply"(%arg1, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]] - %4 = "xla_hlo.multiply"(%2, %3) 
{broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex>, tensor<2xcomplex>) -> (tensor<1x2xcomplex>) - %5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) - %6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) - - // CHECK: return %2, %5 : tensor<1x2xf32>, tensor<1x2xf32> - return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32> -} - // CHECK-LABEL: @mul_unranked func @mul_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) @@ -186,45 +137,6 @@ func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, % // ----- -// CHECK-LABEL: @div_broadcast -func @div_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) { - %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex>) - %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) - - // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.negate"(%arg3) - - // Compute the numerator's real component: - // numerator.real = lhs.real * rhs.real lhs.imag * rhs.imag - // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.multiply"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.multiply"(%arg1, [[VAL0]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.subtract [[VAL1]], [[VAL2]] - - // Compute the real valued denominator as rhs * con(rhs): - // denominator = rhs.real * rhs.real + rhs.imag * rhs.imag - // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg2, %arg2 - // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.multiply %arg3, [[VAL0]] - // CHECK-DAG: [[VAL6:%.+]] = xla_hlo.subtract [[VAL4]], [[VAL5]] - - // Compute the numerator's imaginary component: - // numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag - // CHECK-DAG: [[VAL7:%.+]] = "xla_hlo.multiply"(%arg1, %arg2) - // CHECK-DAG: [[VAL8:%.+]] = "xla_hlo.multiply"(%arg0, [[VAL0]]) - // CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]] - - // Divide the numerator by the real valued denominator. 
- // CHECK-DAG: [[VAL10:%.+]] = "xla_hlo.divide"([[VAL3]], [[VAL6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL11:%.+]] = "xla_hlo.divide"([[VAL9]], [[VAL6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} - %4 = "xla_hlo.divide"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex>, tensor<2xcomplex>) -> (tensor<1x2xcomplex>) - - %5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) - %6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) - - // CHECK: return [[VAL10]], [[VAL11]] - return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32> -} - -// ----- - // CHECK-LABEL: @div_unranked func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) diff --git a/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir b/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir index 4050340ce49..55b55c7b4e2 100644 --- a/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir +++ b/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir @@ -1,214 +1,5 @@ // RUN: xla-opt -test-xla-materialize-broadcasts -split-input-file %s -o - | FileCheck --dump-input=fail %s -// CHECK-LABEL: @addBroadcastRhs -func @addBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @addBroadcastLhs -func @addBroadcastLhs(%arg0: tensor<4xf32>, %arg1: tensor<1x4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %[[BROADCAST0]], %arg1 : tensor<1x4xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<1x4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @addBroadcastMultidimension -func @addBroadcastMultidimension(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x4xf32> { - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>) -> tensor<1x1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %[[BROADCAST0]], %arg1 : tensor<1x1x4xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>, tensor<1x1x4xf32>) -> tensor<1x1x4xf32> - return %0 : tensor<1x1x4xf32> -} - -// ----- - -// CHECK-LABEL: @addBroadcastBothArgs -func @addBroadcastBothArgs(%arg0: tensor<1x2xf32>, %arg1: tensor<3x2x1xf32>) -> tensor<3x2x2xf32> { - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<3x2x2xf32> - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x1xf32>) -> 
tensor<3x2x2xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor<3x2x2xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xf32>, tensor<3x2x1xf32>) -> tensor<3x2x2xf32> - return %0 : tensor<3x2x2xf32> -} - -// ----- - -// CHECK-LABEL: @addBroadcastScalar -func @addBroadcastScalar(%arg0: tensor<4xf32>, %arg1: tensor) -> tensor<4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor) -> tensor<4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %arg0, %[[BROADCAST1]] : tensor<4xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> - return %0 : tensor<4xf32> -} - -// ----- - -// CHECK-LABEL: @addWithoutBroadcast -func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %arg0, %arg1 : tensor<4xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %0 : tensor<4xf32> -} - -// ----- - -// CHECK-LABEL: @addUnranked -func @addUnranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %arg0, %arg1 : tensor<*xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - return %0 : tensor<*xf32> -} - -// ----- - -// CHECK-LABEL: @atan2BroadcastRhs -func @atan2BroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.atan2 %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.atan2"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @divBroadcastRhs -func @divBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.divide %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @maxBroadcastRhs -func @maxBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.maximum %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.maximum"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @minBroadcastRhs -func @minBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.minimum %arg0, %[[BROADCAST1]] : 
tensor<1x4xf32> - %0 = "xla_hlo.minimum"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @mulBroadcastRhs -func @mulBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.multiply %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @powBroadcastRhs -func @powBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.power %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.power"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @remainderBroadcastRhs -func @remainderBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.remainder %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.remainder"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @shiftLeftBroadcastRhs -func @shiftLeftBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.shift_left %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.shift_left"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @shiftRightArithmeticBroadcastRhs -func @shiftRightArithmeticBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.shift_right_arithmetic %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @shiftRightLogicalBroadcastRhs -func @shiftRightLogicalBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.shift_right_logical %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = 
"xla_hlo.shift_right_logical"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @subBroadcastRhs -func @subBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.subtract %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @andBroadcastRhs -func @andBroadcastRhs(%arg0: tensor<1x4xi32>, %arg1: tensor<4xi32>) -> tensor<1x4xi32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1x4xi32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.and %arg0, %[[BROADCAST1]] : tensor<1x4xi32> - %0 = "xla_hlo.and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xi32>, tensor<4xi32>) -> tensor<1x4xi32> - return %0 : tensor<1x4xi32> -} - -// ----- - -// CHECK-LABEL: @orBroadcastRhs -func @orBroadcastRhs(%arg0: tensor<1x4xi32>, %arg1: tensor<4xi32>) -> tensor<1x4xi32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1x4xi32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.or %arg0, %[[BROADCAST1]] : tensor<1x4xi32> - %0 = "xla_hlo.or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xi32>, tensor<4xi32>) -> tensor<1x4xi32> - return %0 : tensor<1x4xi32> -} - -// ----- - -// CHECK-LABEL: @xorBroadcastRhs -func @xorBroadcastRhs(%arg0: tensor<1x4xi32>, %arg1: tensor<4xi32>) -> tensor<1x4xi32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1x4xi32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.xor %arg0, %[[BROADCAST1]] : tensor<1x4xi32> - %0 = "xla_hlo.xor"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xi32>, tensor<4xi32>) -> tensor<1x4xi32> - return %0 : tensor<1x4xi32> -} - -// ----- - // CHECK-LABEL: @clampBroadcast // CHECK-SAME: (%[[MIN:.+]]: tensor, %[[VAL:.+]]: tensor<4xf32>, %[[MAX:.+]]: tensor) func @clampBroadcast(%min: tensor, %value: tensor<4xf32>, %max: tensor) -> tensor<4xf32> { @@ -218,69 +9,3 @@ func @clampBroadcast(%min: tensor, %value: tensor<4xf32>, %max: tensor %0 = "xla_hlo.clamp"(%min, %value, %max) : (tensor, tensor<4xf32>, tensor) -> tensor<4xf32> return %0 : tensor<4xf32> } - -// ----- - -// CHECK-LABEL: @compareBroadcastRhs -func @compareBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xi1> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = "xla_hlo.compare"(%arg0, %[[BROADCAST1]]) {comparison_direction = "NE"} : (tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<1x4xi1> - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xi1> - return %0 : tensor<1x4xi1> -} - -// ----- - -// CHECK-LABEL: 
@dynamicCompareBroadcastRhs -func @dynamicCompareBroadcastRhs(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-NEXT: %[[DIM0:.*]] = dim %arg0, 0 : tensor - // CHECK-NEXT: %[[DIM0C:.*]] = index_cast %[[DIM0]] : index to i32 - // CHECK-NEXT: %c1 = constant 1 : index - // CHECK-NEXT: %[[DIM1_0:.*]] = dim %arg0, 1 : tensor - // CHECK-NEXT: %[[DIM1_1:.*]] = dim %arg1, 0 : tensor - // CHECK-NEXT: %[[CMPI:.*]] = cmpi "eq", %[[DIM1_0]], %c1 : index - // CHECK-NEXT: %[[SEL:.*]] = select %[[CMPI]], %[[DIM1_0]], %[[DIM1_1]] : index - // CHECK-NEXT: %[[DIM1C:.*]] = index_cast %[[SEL]] : index to i32 - // CHECK-NEXT: %[[SHAPE:.*]] = "xla_hlo.scalars_to_dimension_tensor"(%[[DIM0C]], %[[DIM1C]]) : (i32, i32) -> tensor<2xi32> - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[SHAPE]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xi32>) -> tensor - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[SHAPE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xi32>) -> tensor - // CHECK-NEXT: "xla_hlo.compare"(%[[BROADCAST0]], %[[BROADCAST1]]) {comparison_direction = "NE"} : (tensor, tensor) -> tensor - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor, tensor) -> tensor - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @dynamicBroadcastAdd -func @dynamicBroadcastAdd(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-NEXT: %[[DIM0:.*]] = dim %arg0, 0 : tensor - // CHECK-NEXT: %[[DIM0C:.*]] = index_cast %[[DIM0]] : index to i32 - // CHECK-NEXT: %c1 = constant 1 : index - // CHECK-NEXT: %[[DIM1_0:.*]] = dim %arg0, 1 : tensor - // CHECK-NEXT: %[[DIM1_1:.*]] = dim %arg1, 0 : tensor - // CHECK-NEXT: %[[CMPI:.*]] = cmpi "eq", %[[DIM1_0]], %c1 : index - // CHECK-NEXT: %[[SEL:.*]] = select %[[CMPI]], %[[DIM1_0]], %[[DIM1_1]] : index - // CHECK-NEXT: %[[DIM1C:.*]] = index_cast %[[SEL]] : index to i32 - // CHECK-NEXT: %[[SHAPE:.*]] = "xla_hlo.scalars_to_dimension_tensor"(%[[DIM0C]], %[[DIM1C]]) : (i32, i32) -> tensor<2xi32> - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[SHAPE]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xi32>) -> tensor - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[SHAPE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xi32>) -> tensor - // CHECK-NEXT: xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor) -> tensor - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @dynamicBroadcastAddScalar -func @dynamicBroadcastAddScalar(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-NEXT: %[[DIM0:.*]] = dim %arg0, 0 : tensor - // CHECK-NEXT: %[[DIM0C:.*]] = index_cast %[[DIM0]] : index to i32 - // CHECK-NEXT: %[[DIM1:.*]] = dim %arg0, 1 : tensor - // CHECK-NEXT: %[[DIM1C:.*]] = index_cast %[[DIM1]] : index to i32 - // CHECK-NEXT: %[[SHAPE:.*]] = "xla_hlo.scalars_to_dimension_tensor"(%[[DIM0C]], %[[DIM1C]]) : (i32, i32) -> tensor<2xi32> - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[SHAPE]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xi32>) -> tensor - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[SHAPE]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<2xi32>) -> 
tensor - // CHECK-NEXT: xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor - return %0 : tensor -} diff --git a/tensorflow/compiler/mlir/xla/tests/ops.mlir b/tensorflow/compiler/mlir/xla/tests/ops.mlir index 8cb63311657..0a69ee93aee 100644 --- a/tensorflow/compiler/mlir/xla/tests/ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/ops.mlir @@ -156,6 +156,98 @@ func @broadcast_in_dim_bad_shape_mismatch(%arg0: tensor<3xi32>) -> tensor<1x2x3x // ----- +func @case_mismatch_num_args(%index: tensor, %operand_1: tensor, %operand_2: tensor, %operand_3: tensor) -> tensor { + // expected-error@+1 {{expects branch regions to have single argument, but found 2 for branch 1}} + %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.negate"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg0: tensor, %arg1: tensor): + %1 = "xla_hlo.copy"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.floor"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + } + ) : (tensor, tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- + +func @case_mismatch_num_results(%index: tensor, %operand_1: tensor, %operand_2: tensor, %operand_3: tensor) -> tensor { + // expected-error@+1 {{branch 1 returned values do not match op result types}} + %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.negate"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.copy"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1, %arg0) : (tensor, tensor) -> () + }, { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.floor"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + } + ) : (tensor, tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- + +func @case_mismatch_arg_type(%index: tensor, %operand_1: tensor, %operand_2: tensor, %operand_3: tensor) -> tensor { + // expected-error@+1 {{expects operand 2 to be of type 'tensor', but found 'tensor'}} + %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.negate"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg0: tensor): + %1 = xla_hlo.constant dense<2.0> : tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.floor"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + } + ) : (tensor, tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- + +func @case_mismatch_return_type(%index: tensor, %operand_1: tensor, %operand_2: tensor, %operand_3: tensor) -> tensor { + // expected-error@+1 {{branch 1 returned values do not match op result types}} + %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.negate"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg0: tensor): + %1 = xla_hlo.constant dense<2> : tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.floor"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + } + ) : (tensor, tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- + +func @case_empty_region(%index: tensor, %operand_1: tensor) -> () { + // 
expected-error@+1 {{cannot have empty regions}} + "xla_hlo.case"(%index, %operand_1) ( {} ) : (tensor, tensor) -> tensor + return +} + +// ----- + // CHECK-LABEL: func @comp_eq func @comp_eq(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3xi1> { %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> @@ -453,14 +545,6 @@ func @rng_uniform_invalid_type(%mu: tensor>, %sigma: tensor) - // ----- -// CHECK-LABEL: @scalars_to_dimension_tensor -func @scalars_to_dimension_tensor(%arg0: i32, %arg1: i32) -> tensor<2xi32> { - %0 = "xla_hlo.scalars_to_dimension_tensor"(%arg0, %arg1) : (i32, i32) -> tensor<2xi32> - return %0 : tensor<2xi32> -} - -// ----- - // CHECK-LABEL: func @select func @select(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> diff --git a/tensorflow/compiler/mlir/xla/tests/sink-constants-to-control-flow.mlir b/tensorflow/compiler/mlir/xla/tests/sink-constants-to-control-flow.mlir new file mode 100644 index 00000000000..9f54e40dcaa --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/sink-constants-to-control-flow.mlir @@ -0,0 +1,60 @@ +// RUN: xla-opt %s -xla-hlo-sink-constants-to-control-flow | FileCheck %s --dump-input=fail + +// Tests sinking constants to a while loop. + +// CHECK-LABEL: func @sink_const_to_while +func @sink_const_to_while(%arg0: tensor) -> tensor { + // CHECK-NEXT: xla_hlo.while + %c0 = xla_hlo.constant dense<1> : tensor + %c1 = xla_hlo.constant dense<2> : tensor + %0 = "xla_hlo.while"(%arg0) ( { + ^bb0(%arg1: tensor): + // CHECK: %[[ARG1A:.+]]: tensor + // CHECK: %[[C0:.+]] = xla_hlo.constant dense<1> : tensor + // CHECK: "xla_hlo.compare"(%[[C0]], %[[ARG1A]]) + %1 = "xla_hlo.compare"(%c0, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg1: tensor): + // CHECK: %[[ARG1B:.+]]: tensor + // CHECK-DAG: %[[C1:.+]] = xla_hlo.constant dense<2> : tensor + // CHECK-DAG: %[[ADD0:.+]] = xla_hlo.add %[[ARG1B]], %[[ARG1B]] + %2 = xla_hlo.add %arg1, %arg1 : tensor + // CHECK: %[[ADD1:.+]] = xla_hlo.add %[[C1]], %[[ADD0]] + %3 = xla_hlo.add %c1, %2 : tensor + // CHECK: %[[ADD2:.+]] = xla_hlo.add %[[C1]], %[[ADD1]] + %4 = xla_hlo.add %c1, %3 : tensor + "xla_hlo.return"(%4) : (tensor) -> () + }) : (tensor) -> tensor + return %0 : tensor +} + +// Tests sinking constants to a conditional op. 
+ +// CHECK-LABEL: func @sink_const_to_conditional +func @sink_const_to_conditional(%arg0: tensor) -> tensor { + %c0 = xla_hlo.constant dense<1> : tensor + %c1 = xla_hlo.constant dense<2> : tensor + %0 = "xla_hlo.compare"(%arg0, %c0) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + %1 = "xla_hlo.tuple"(%arg0) : (tensor) -> tuple> + // CHECK: xla_hlo.if + %2 = "xla_hlo.if"(%0, %1, %1) ( { + ^bb0(%arg1: tuple>): + // CHECK: %[[C0:.+]] = xla_hlo.constant dense<1> : tensor + %3 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple>) -> tensor + // CHECK: %[[ADD0:.+]] = xla_hlo.add %[[C0]], + %4 = xla_hlo.add %c0, %3 : tensor + %5 = "xla_hlo.tuple"(%4) : (tensor) -> tuple> + "xla_hlo.return"(%5) : (tuple>) -> () + }, { + ^bb0(%arg1: tuple>): + // CHECK: %[[C1:.+]] = xla_hlo.constant dense<2> : tensor + %6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple>) -> tensor + // CHECK: %[[ADD1:.+]] = xla_hlo.add %[[C1]], + %7 = xla_hlo.add %c1, %6 : tensor + %8 = "xla_hlo.tuple"(%7) : (tensor) -> tuple> + "xla_hlo.return"(%8) : (tuple>) -> () + }) : (tensor, tuple>, tuple>) -> tuple> + %9 = "xla_hlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple>) -> tensor + return %9 : tensor +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/case.mlir b/tensorflow/compiler/mlir/xla/tests/translate/case.mlir new file mode 100644 index 00000000000..dba9e8b61ca --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/case.mlir @@ -0,0 +1,99 @@ +// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s + +func @main() -> tensor { + %cst = constant {name = "constant"} dense<1> : tensor + %cst_0 = constant {name = "constant.1"} dense<5.600000e+01> : tensor + %cst_1 = constant {name = "constant.2"} dense<1.200000e+01> : tensor + %cst_2 = constant {name = "constant.3"} dense<1.300000e+01> : tensor + %0 = "xla_hlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.negate"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.copy"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.floor"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + }) {name = "conditional"} : (tensor, tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK: %[[NEGATE_BRANCH:.*]] ({{.*}}: f32[]) -> f32[] { +// CHECK: %[[ARG:.*]] = f32[] parameter(0) +// CHECK: ROOT %[[RESULT:.*]] = f32[] negate(f32[] %[[ARG]]) +// CHECK: } + +// CHECK: %[[COPY_BRANCH:.*]] ({{.*}}: f32[]) -> f32[] { +// CHECK: %[[ARG:.*]] = f32[] parameter(0) +// CHECK: ROOT %[[RESULT:.*]] = f32[] copy(f32[] %[[ARG]]) +// CHECK: } + +// CHECK: %[[FLOOR_BRANCH:.*]] ({{.*}}: f32[]) -> f32[] { +// CHECK: %[[ARG:.*]] = f32[] parameter(0) +// CHECK: ROOT %[[RESULT:.*]] = f32[] floor(f32[] %[[ARG]]) +// CHECK: } + +// CHECK-LABEL: ENTRY +// CHECK-SAME: () -> f32[] + +// CHECK: %[[INDEX:.*]] = s32[] constant(1) +// CHECK: %[[OPERAND_1:.*]] = f32[] constant(56) +// CHECK: %[[OPERAND_2:.*]] = f32[] constant(12) +// CHECK: %[[OPERAND_3:.*]] = f32[] constant(13) +// CHECK: ROOT %[[RESULT:.*]] = f32[] conditional(s32[] %[[INDEX]], f32[] %[[OPERAND_1]], f32[] %[[OPERAND_2]], f32[] %[[OPERAND_3]]), branch_computations={%[[NEGATE_BRANCH]], %[[COPY_BRANCH]], %[[FLOOR_BRANCH]]} + +// ----- + +func @main() -> (tensor, tensor) { + %cst = constant {name = "constant"} dense<1> : tensor + %cst_0 = constant {name = "constant.1"} dense<5.600000e+01> 
: tensor + %cst_1 = constant {name = "constant.2"} dense<1.200000e+01> : tensor + %cst_2 = constant {name = "constant.3"} dense<1.300000e+01> : tensor + %0:2 = "xla_hlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.negate"(%arg0) {name = "negate"} : (tensor) -> tensor + "xla_hlo.return"(%1, %1) : (tensor, tensor) -> () + }, { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.copy"(%arg0) {name = "copy"} : (tensor) -> tensor + "xla_hlo.return"(%1, %1) : (tensor, tensor) -> () + }, { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.floor"(%arg0) {name = "floor"} : (tensor) -> tensor + "xla_hlo.return"(%1, %1) : (tensor, tensor) -> () + }) {name = "conditional"} : (tensor, tensor, tensor, tensor) -> (tensor, tensor) + return %0#0, %0#1 : tensor, tensor +} + +// CHECK: %[[NEGATE_BRANCH:.*]] ({{.*}}: f32[]) -> (f32[], f32[]) { +// CHECK: %[[ARG:.*]] = f32[] parameter(0) +// CHECK: %[[NEGATE:.*]] = f32[] negate(f32[] %[[ARG]]) +// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(f32[] %[[NEGATE]], f32[] %[[NEGATE]]) +// CHECK: } + +// CHECK: %[[COPY_BRANCH:.*]] ({{.*}}: f32[]) -> (f32[], f32[]) { +// CHECK: %[[ARG:.*]] = f32[] parameter(0) +// CHECK: %[[COPY:.*]] = f32[] copy(f32[] %[[ARG]]) +// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(f32[] %[[COPY]], f32[] %[[COPY]]) +// CHECK: } + +// CHECK: %[[FLOOR_BRANCH:.*]] ({{.*}}: f32[]) -> (f32[], f32[]) { +// CHECK: %[[ARG:.*]] = f32[] parameter(0) +// CHECK: %[[FLOOR:.*]] = f32[] floor(f32[] %[[ARG]]) +// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(f32[] %[[FLOOR]], f32[] %[[FLOOR]]) +// CHECK: } + +// CHECK-LABEL: ENTRY +// CHECK-SAME: () -> (f32[], f32[]) + +// CHECK: %[[INDEX:.*]] = s32[] constant(1) +// CHECK: %[[OPERAND_1:.*]] = f32[] constant(56) +// CHECK: %[[OPERAND_2:.*]] = f32[] constant(12) +// CHECK: %[[OPERAND_3:.*]] = f32[] constant(13) +// CHECK: %[[TUPLE:.*]] = (f32[], f32[]) conditional(s32[] %[[INDEX]], f32[] %[[OPERAND_1]], f32[] %[[OPERAND_2]], f32[] %[[OPERAND_3]]), branch_computations={%[[NEGATE_BRANCH]], %[[COPY_BRANCH]], %[[FLOOR_BRANCH]]} +// CHECK: %[[RES_1:.*]] = f32[] get-tuple-element((f32[], f32[]) %[[TUPLE]]), index=0 +// CHECK: %[[RES_2:.*]] = f32[] get-tuple-element((f32[], f32[]) %[[TUPLE]]), index=1 +// CHECK: ROOT %[[RESULT:.*]] = (f32[], f32[]) tuple(f32[] %[[RES_1]], f32[] %[[RES_2]]) diff --git a/tensorflow/compiler/mlir/xla/tests/translate/case_conditional.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/case_conditional.hlotxt new file mode 100644 index 00000000000..2ff223cd480 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/case_conditional.hlotxt @@ -0,0 +1,46 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule Indexed_Conditional + +%Negate (x: f32[]) -> f32[] { + %x = f32[] parameter(0) + ROOT %negate = f32[] negate(f32[] %x) +} + +%Identity (y: f32[]) -> f32[] { + %y = f32[] parameter(0) + ROOT %copy = f32[] copy(f32[] %y) +} + +%Floor (z: f32[]) -> f32[] { + %z = f32[] parameter(0) + ROOT %floor = f32[] floor(f32[] %z) +} + +ENTRY %indexed_conditional () -> f32[] { + %constant = s32[] constant(1) + %constant.1 = f32[] constant(56) + %constant.2 = f32[] constant(12) + %constant.3 = f32[] constant(13) + ROOT %conditional = f32[] conditional(s32[] %constant, f32[] %constant.1, f32[] %constant.2, f32[] %constant.3), branch_computations={%Negate, %Identity, %Floor} +} + +// CHECK-LABEL: func @main() -> tensor +// CHECK: %[[INDEX:.*]] = constant {name = "constant"} dense<1> : tensor +// CHECK: %[[OPERAND_1:.*]] = constant {name 
= "{{.*}}"} dense<5.600000e+01> : tensor +// CHECK: %[[OPERAND_2:.*]] = constant {name = "{{.*}}"} dense<1.200000e+01> : tensor +// CHECK: %[[OPERAND_3:.*]] = constant {name = "{{.*}}"} dense<1.300000e+01> : tensor +// CHECK: %[[RESULT:.*]] = "xla_hlo.case"(%[[INDEX]], %[[OPERAND_1]], %[[OPERAND_2]], %[[OPERAND_3]]) ( { +// CHECK: ^bb0(%[[ARG_1:.*]]: tensor): +// CHECK: %[[RES_1:.*]] = "xla_hlo.negate"(%[[ARG_1]]) {name = "{{.*}}"} : (tensor) -> tensor +// CHECK: "xla_hlo.return"(%[[RES_1]]) : (tensor) -> () +// CHECK: }, { +// CHECK: ^bb0(%[[ARG_2:.*]]: tensor): +// CHECK: %[[RES_2:.*]] = "xla_hlo.copy"(%[[ARG_2]]) {name = "{{.*}}"} : (tensor) -> tensor +// CHECK: "xla_hlo.return"(%[[RES_2]]) : (tensor) -> () +// CHECK: }, { +// CHECK: ^bb0(%[[ARG_3:.*]]: tensor): +// CHECK: %[[RES_3:.*]] = "xla_hlo.floor"(%[[ARG_3]]) {name = "{{.*}}"} : (tensor) -> tensor +// CHECK: "xla_hlo.return"(%[[RES_3]]) : (tensor) -> () +// CHECK: }) {name = "{{.*}}"} : (tensor, tensor, tensor, tensor) -> tensor +// CHECK: return %[[RESULT]] : tensor diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir index 3650307ea94..20b43e8633d 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir @@ -1,4 +1,4 @@ -// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s +// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s --dump-input-on-failure // CHECK: HloModule func @main(%arg0: !xla_hlo.token, %arg1: !xla_hlo.token) -> !xla_hlo.token { @@ -96,34 +96,6 @@ func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor // ----- -// CHECK: HloModule -func @main(%arg0: tensor<1x4xi32>, %arg1: tensor<2x4xi32>, %arg2: tensor<2x3x4xi32>) -> tensor<2x3x4xi32> { - // Same rank degenerate broadcast - // CHECK: [[ARG_0:%.*]] = s32[1,4] parameter(0) - // CHECK-NEXT: [[RESHAPE_1:%.*]] = s32[4] reshape(s32[1,4] [[ARG_0]]) - // CHECK-NEXT: [[BROADCAST_1:%.*]] = s32[2,4] broadcast(s32[4] [[RESHAPE_1]]) - // CHECK-NEXT: [[ARG_1:%.*]] = s32[2,4] parameter(1) - // CHECK-NEXT: s32[2,4] add(s32[2,4] [[BROADCAST_1]], s32[2,4] [[ARG_1]]) - %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<1x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> - - // Broadcast up rank - // CHECK-NEXT: [[BROADCAST_2:%.*]] = s32[2,3,4] broadcast(s32[2,4] [[ARG_1]]), dimensions={0,2} - // CHECK-NEXT: [[ARG_2:%.*]] = s32[2,3,4] parameter(2) - // CHECK-NEXT: s32[2,3,4] add(s32[2,3,4] [[BROADCAST_2]], s32[2,3,4] [[ARG_2]]) - %1 = "xla_hlo.add"(%arg1, %arg2) {broadcast_dimensions = dense<[0,2]> : tensor<2xi64>} : (tensor<2x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32> - - // Broadcast up rank + degenerate broadcast - // CHECK-NEXT: [[BROADCAST_3:%.*]] = s32[2,1,4] broadcast(s32[1,4] [[ARG_0]]), dimensions={1,2} - // CHECK-NEXT: [[RESHAPE_2:%.*]] = s32[2,4] reshape(s32[2,1,4] [[BROADCAST_3]]) - // CHECK-NEXT: [[BROADCAST_4:%.*]] = s32[2,3,4] broadcast(s32[2,4] [[RESHAPE_2]]), dimensions={0,2} - // CHECK: ROOT - // CHECK-SAME: s32[2,3,4] add(s32[2,3,4] [[BROADCAST_4]], s32[2,3,4] [[ARG_2]]) - %2 = "xla_hlo.add"(%arg0, %arg2) {broadcast_dimensions = dense<[1,2]> : tensor<2xi64>} : (tensor<1x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32> - return %2 : tensor<2x3x4xi32> -} - -// ----- - // CHECK: HloModule func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> { %0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32> @@ -294,6 +266,12 @@ func 
@main() { // CHECK: f16[4] constant({1, -4, -65504, 0.015625} %cst_8 = constant dense<[1.0e+00, -4.0e+00, -65504.0e+00, 1.5625e-02]> : tensor<4xf16> + // CHECK: c64[] constant((1, 0)) + %cst_9 = constant dense<(1.000000e+00,0.000000e+00)> : tensor> + + // CHECK: c128[] constant((1, 0)) + %cst_10 = constant dense<(1.000000e+00,0.000000e+00)> : tensor> + return } @@ -1038,3 +1016,16 @@ func @main(%arg0: tensor<4xui8>) -> (tensor<4xui8>) { // CHECK: ENTRY // CHECK: %[[ARG0:.*]] = u8[4] parameter(0) // ROOT %[[RESULT:.*]] = u8[4] not(u8[4] %[[ARG0]]) + +// ----- + +// CHECK: HloModule +func @main(%arg0: tensor<4xi32>) -> (tensor<*xi32>) { + %0 = "xla_hlo.not"(%arg0) : (tensor<4xi32>) -> tensor<4xi32> + %1 = tensor_cast %0 : tensor<4xi32> to tensor<*xi32> + return %1 : tensor<*xi32> +} + +// CHECK: ENTRY +// CHECK: %[[ARG0:.*]] = s32[4] parameter(0) +// ROOT %[[RESULT:.*]] = s32[4] not(s32[4] %[[ARG0]]) diff --git a/tensorflow/compiler/mlir/xla/tests/translate/conditional.mlir b/tensorflow/compiler/mlir/xla/tests/translate/if.mlir similarity index 98% rename from tensorflow/compiler/mlir/xla/tests/translate/conditional.mlir rename to tensorflow/compiler/mlir/xla/tests/translate/if.mlir index e510a2aa35f..6542966fc7c 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/conditional.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/if.mlir @@ -41,7 +41,7 @@ func @main(%arg0: tensor) -> tuple> { %1 = "xla_hlo.tuple"(%arg0) : (tensor) -> tuple> // CHECK: %[[VAL3:.+]] = (f32[]) conditional(pred[] %[[VAL1]], (f32[]) %[[VAL2]], (f32[]) %[[VAL2]]), true_computation=[[R0]], false_computation=[[R1]] - %2 = "xla_hlo.conditional"(%0, %1, %1) ( { + %2 = "xla_hlo.if"(%0, %1, %1) ( { ^bb0(%arg1: tuple>): %6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple>) -> tensor %7 = "xla_hlo.log"(%6) : (tensor) -> tensor diff --git a/tensorflow/compiler/mlir/xla/tests/translate/conditional.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/if_conditional.hlotxt similarity index 97% rename from tensorflow/compiler/mlir/xla/tests/translate/conditional.hlotxt rename to tensorflow/compiler/mlir/xla/tests/translate/if_conditional.hlotxt index 00f6ec2d308..d2c6e669e9b 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/conditional.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/if_conditional.hlotxt @@ -29,7 +29,7 @@ ENTRY %tfcompile.20 { // CHECK: [[R2:%.+]] = "xla_hlo.tuple"([[A0]]) %tuple.5 = (f32[]) tuple(%arg0.1), metadata={op_type="If" op_name="cond/Merge_if"} - // CHECK: [[R3:%.+]] = "xla_hlo.conditional"([[R1]], [[R2]], [[R2]]) ( { + // CHECK: [[R3:%.+]] = "xla_hlo.if"([[R1]], [[R2]], [[R2]]) ( { // CHECK: ^bb0([[A1:%.+]]: tuple>): // CHECK: [[R7:%.+]] = "xla_hlo.get_tuple_element"([[A1]]) // CHECK: [[R8:%.+]] = "xla_hlo.log"([[R7]]) diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt index 75471e3a090..af45f84b34d 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt @@ -1,4 +1,4 @@ -// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s --dump-input-on-failure HloModule main @@ -20,29 +20,6 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { ROOT %dot.4 = f32[] dot(f32[4]{0} %add.3, f32[4]{0} %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0} } -// This test is more thorough than those of the the other binary ops to 
test -// their shared functionality. - -// CHECK-LABEL: func @test_add -%test_add (Arg_0.1: f32[4], Arg_1.2: f32[4], Arg_2.3: f32[], Arg_3.4: f32[]) -> f32[4] { - %Arg_0.1 = f32[4] parameter(0) - %Arg_1.2 = f32[4] parameter(1) - %Arg_2.3 = f32[] parameter(2) - %Arg_3.4 = f32[] parameter(3) - - // Add two tensors - // CHECK-NEXT: xla_hlo.add %arg0, %arg1 {name = "{{.*}}"} - %add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2) - - // Add two scalars - // CHECK-NEXT: xla_hlo.add %arg2, %arg3 - %add.4 = f32[] add(f32[] %Arg_2.3, f32[] %Arg_3.4) - - // Add a tensor and scalar - // CHECK-NEXT: "xla_hlo.add"(%0, %1) - ROOT %add.5 = f32[4] add(f32[4] %add.3, f32[] %add.4) -} - // CHECK-LABEL: func @test_after_all // CHECK-SAME: ([[VAL_0:%.*]]: !xla_hlo.token, [[VAL_1:%.*]]: !xla_hlo.token) -> !xla_hlo.token %test_after_all (token0: token[], token1: token[] ) -> token[] { @@ -159,11 +136,11 @@ add { } -// CHECK-LABEL: func @test_compare(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>, %arg2: tensor<1xf32>) -> tensor<3xi1> { -%test_compare (Arg_0.1: f32[3], Arg_1.2: f32[3], Arg_2.3: f32[1]) -> pred[3] { +// CHECK-LABEL: func @test_compare(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>, %arg2: tensor<3xf32>) -> tensor<3xi1> { +%test_compare (Arg_0.1: f32[3], Arg_1.2: f32[3], Arg_2.3: f32[3]) -> pred[3] { %Arg_0.1 = f32[3] parameter(0) %Arg_1.2 = f32[3] parameter(1) - %Arg_2.3 = f32[1] parameter(2) + %Arg_2.3 = f32[3] parameter(2) // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> %compare.4 = pred[3] compare(Arg_0.1, Arg_1.2), direction=EQ @@ -172,7 +149,7 @@ add { %compare.5 = pred[3] compare(Arg_0.1, Arg_1.2), direction=LE // Requires broadcast of compatible tensors. - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg2) {comparison_direction = "GT", name = "{{.*}}"} : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xi1> + // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg2) {comparison_direction = "GT", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> ROOT %compare.6 = pred[3] compare(Arg_0.1, Arg_2.3), direction=GT } @@ -212,10 +189,14 @@ add { // CHECK: dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xbf16> %constant.3 = bf16[4] constant({1, 2, 3, 4}) + // CHECK: dense<(1.000000e+00,0.000000e+00)> : tensor> + %constant.4 = c64[] constant((1, 0)) + + // CHECK: dense<(1.000000e+00,0.000000e+00)> : tensor> + %constant.5 = c128[] constant((1, 0)) + // CHECK: dense<[1.000000e+00, -4.000000e+00, -6.550400e+04, 1.562500e-02]> : tensor<4xf16> - ROOT %constant.4 = f16[4] constant({1, -4, -65504, 0.015625}) - - + ROOT %constant.6 = f16[4] constant({1, -4, -65504, 0.015625}) } // TODO(b/129422361) Potentially update when copy, reshape, and conv have actual @@ -276,19 +257,19 @@ add { ROOT %convolution = f32[1,5,1] convolution(f32[1,2,1] %input, f32[1,1,1] %filter), feature_group_count=1, dim_labels=b0f_0io->b0f, window={pad=1_2 size=1} } -// CHECK-LABEL: func @test_convert(%arg0: tensor<4xf32>, %arg1: tensor) -> tensor<4xf64> { -%test_convert (Arg_0.1: f32[4], Arg_1.2: f32[]) -> f64[4] { +// CHECK-LABEL: func @test_convert(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf64> { +%test_convert (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f64[4] { %Arg_0.1 = f32[4] parameter(0) - %Arg_1.2 = f32[] parameter(1) + %Arg_1.2 = f32[4] parameter(1) // CHECK-NEXT: %0 = "xla_hlo.convert"(%arg0) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64> %convert.3 = f64[4] convert(f32[4] %Arg_0.1) - // CHECK-NEXT: %1 
= "xla_hlo.convert"(%arg1) {name = "{{.*}}"} : (tensor) -> tensor - %convert.4 = f64[] convert(f32[] %Arg_1.2) + // CHECK-NEXT: %1 = "xla_hlo.convert"(%arg1) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64> + %convert.4 = f64[4] convert(f32[4] %Arg_1.2) - // CHECK-NEXT: "xla_hlo.add"(%0, %1) - ROOT %add.5 = f64[4] add(f64[4] %convert.3, f64[] %convert.4) + // CHECK-NEXT: xla_hlo.add %0, %1 + ROOT %add.5 = f64[4] add(f64[4] %convert.3, f64[4] %convert.4) } // CHECK-LABEL: func @test_cosine(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> { diff --git a/tensorflow/compiler/mlir/xla/tests/unfuse_batch_norm.mlir b/tensorflow/compiler/mlir/xla/tests/unfuse_batch_norm.mlir index 9778772e250..3d8646e7fb9 100644 --- a/tensorflow/compiler/mlir/xla/tests/unfuse_batch_norm.mlir +++ b/tensorflow/compiler/mlir/xla/tests/unfuse_batch_norm.mlir @@ -106,24 +106,19 @@ func @batchNormInference_dynamic_shape( -> tensor { // CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e-03> : tensor // CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], 0 : tensor - // CHECK-DAG: %[[INDEX_CAST:.+]] = index_cast %[[DIM]] : index to i32 - // CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = "xla_hlo.scalars_to_dimension_tensor"(%[[INDEX_CAST]]) : (i32) -> tensor<1xi32> - // CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> tensor + // CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor_from_elements(%[[DIM]]) : tensor<1xindex> + // CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor // CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor // CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor) -> tensor // CHECK-DAG: %[[INPUT_DIM_0:.+]] = dim %[[X]], 0 : tensor - // CHECK-DAG: %[[INPUT_INDEX_CAST_0:.+]] = index_cast %[[INPUT_DIM_0]] : index to i32 // CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], 1 : tensor - // CHECK-DAG: %[[INPUT_INDEX_CAST_1:.+]] = index_cast %[[INPUT_DIM_1]] : index to i32 // CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], 2 : tensor - // CHECK-DAG: %[[INPUT_INDEX_CAST_2:.+]] = index_cast %[[INPUT_DIM_2]] : index to i32 // CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], 3 : tensor - // CHECK-DAG: %[[INPUT_INDEX_CAST_3:.+]] = index_cast %[[INPUT_DIM_3]] : index to i32 - // CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = "xla_hlo.scalars_to_dimension_tensor"(%[[INPUT_INDEX_CAST_0]], %[[INPUT_INDEX_CAST_1]], %[[INPUT_INDEX_CAST_2]], %[[INPUT_INDEX_CAST_3]]) : (i32, i32, i32, i32) -> tensor<4xi32> - // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xi32>) -> tensor - // CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xi32>) -> tensor - // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xi32>) -> tensor - // CHECK-DAG: %[[MEAN_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xi32>) -> tensor + // CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = 
tensor_from_elements(%[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]]) : tensor<4xindex> + // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[MEAN_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor // CHECK-DAG: %[[X_CENTER:.+]] = xla_hlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor // CHECK-DAG: %[[X_SCALED:.+]] = xla_hlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor // CHECK-DAG: %[[X_NORMED:.+]] = xla_hlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor diff --git a/tensorflow/compiler/mlir/xla/tests/xla-transform-unranked-hlo.mlir b/tensorflow/compiler/mlir/xla/tests/xla-transform-unranked-hlo.mlir new file mode 100644 index 00000000000..8b26a5e4121 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/xla-transform-unranked-hlo.mlir @@ -0,0 +1,65 @@ +// RUN: xla-opt -transform-unranked-hlo -split-input-file %s | FileCheck --dump-input=fail %s + +// Check the validity of expected IR. +// CHECK-LABEL: @sqr_transform_result +func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> { + + // Flatten operand shape. + %shape = shape.shape_of %a : tensor<*xf32> + %num_elements = shape.num_elements %shape + %num_elements_as_index = shape.size_to_index %num_elements + %flat_shape = tensor_from_elements(%num_elements_as_index) : tensor<1xindex> + %flat_a = "xla_hlo.dynamic_reshape"(%a, %flat_shape) + : (tensor<*xf32>, tensor<1xindex>) -> tensor + + // Apply operation. + %flat_b = "xla_hlo.sqrt"(%flat_a) : (tensor) -> tensor + + // Restore original shape. + %shape_as_extent_tensor = "shape.to_extent_tensor"(%shape) + : (!shape.shape) -> tensor + %b = "xla_hlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor) + : (tensor, tensor) -> tensor<*xf32> + + return %b : tensor<*xf32> +} + +// ----- +// Check transformation of unranked code. 
+// CHECK-LABEL: @sqrt +// CHECK-SAME: (%[[A:.*]]: tensor<*xf32>) +func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> { + // CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> + // CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]] + // CHECK-NEXT: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]] + // CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex> + // CHECK-NEXT: %[[FLAT_A:.*]] = "xla_hlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor + // CHECK-NEXT: %[[FLAT_B:.*]] = "xla_hlo.sqrt"(%[[FLAT_A]]) : (tensor) -> tensor + // CHECK-NEXT: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = "shape.to_extent_tensor"(%[[SHAPE]]) : (!shape.shape) -> tensor + // CHECK-NEXT: %[[B:.*]] = "xla_hlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor, tensor) -> tensor<*xf32> + // CHECK-NEXT: return %[[B]] : tensor<*xf32> + %b = "xla_hlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32> + return %b : tensor<*xf32> +} + +// ----- +// Not transformed when ranked. +// CHECK-LABEL: @sqrt_ranked +// CHECK-SAME: (%[[A:.*]]: tensor<3x?xf32>) +func @sqrt_ranked(%a: tensor<3x?xf32>) -> tensor<3x?xf32> { + // CHECK-NEXT: %[[B:.*]] = "xla_hlo.sqrt"(%[[A]]) : (tensor<3x?xf32>) -> tensor<3x?xf32> + // CHECK-NEXT: return %[[B]] : tensor<3x?xf32> + %b = "xla_hlo.sqrt"(%a) : (tensor<3x?xf32>) -> tensor<3x?xf32> + return %b : tensor<3x?xf32> +} + +// ----- +// Not transformed when statically shaped. +// CHECK-LABEL: @sqrt_static +// CHECK-SAME: (%[[A:.*]]: tensor<2x3xf32>) +func @sqrt_static(%a: tensor<2x3xf32>) -> tensor<2x3xf32> { + // CHECK-NEXT: %[[B:.*]] = "xla_hlo.sqrt"(%[[A]]) : (tensor<2x3xf32>) -> tensor<2x3xf32> + // CHECK-NEXT: return %[[B]] : tensor<2x3xf32> + %b = "xla_hlo.sqrt"(%a) : (tensor<2x3xf32>) -> tensor<2x3xf32> + return %b : tensor<2x3xf32> +} diff --git a/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.cc b/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.cc deleted file mode 100644 index 640b9b84622..00000000000 --- a/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.cc +++ /dev/null @@ -1,485 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file implements logic for computing proper alloc and dealloc positions. -// The main class is the BufferAssignment class that realizes this analysis. -// In order to put allocations and deallocations at safe positions, it is -// significantly important to put them into the proper blocks. However, the -// liveness analysis does not pay attention to aliases, which can occur due to -// branches (and their associated block arguments) in general. For this purpose, -// BufferAssignment firstly finds all possible aliases for a single value (using -// the BufferAssignmentAliasAnalysis class). 
Consider the following example: -// -// ^bb0(%arg0): -// cond_br %cond, ^bb1, ^bb2 -// ^bb1: -// br ^exit(%arg0) -// ^bb2: -// %new_value = ... -// br ^exit(%new_value) -// ^exit(%arg1): -// return %arg1; -// -// Using liveness information on its own would cause us to place the allocs and -// deallocs in the wrong block. This is due to the fact that %new_value will not -// be liveOut of its block. Instead, we have to place the alloc for %new_value -// in bb0 and its associated dealloc in exit. Using the class -// BufferAssignmentAliasAnalysis, we will find out that %new_value has a -// potential alias %arg1. In order to find the dealloc position we have to find -// all potential aliases, iterate over their uses and find the common -// post-dominator block. In this block we can safely be sure that %new_value -// will die and can use liveness information to determine the exact operation -// after which we have to insert the dealloc. Finding the alloc position is -// highly similar and non- obvious. Again, we have to consider all potential -// aliases and find the common dominator block to place the alloc. -// -// TODO(dfki): -// The current implementation does not support loops. The only thing that -// is currently missing is a high-level loop analysis that allows us to move -// allocs and deallocs outside of the loop blocks. - -#include "tensorflow/compiler/mlir/xla/transforms/buffer_assignment.h" - -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "absl/memory/memory.h" - -namespace mlir { -namespace xla { -namespace { - -//===----------------------------------------------------------------------===// -// BufferAssignmentAliasAnalysis -//===----------------------------------------------------------------------===// - -/// A straight-forward alias analysis which ensures that all aliases of all -/// values will be determined. This is a requirement for the BufferAssignment -/// class since you need to determine safe positions to place alloc and -/// deallocs. -class BufferAssignmentAliasAnalysis { - public: - using ValueSetT = SmallPtrSet; - - public: - /// Constructs a new alias analysis using the op provided. - BufferAssignmentAliasAnalysis(Operation* op) { build(op->getRegions()); } - - /// Finds all immediate and indirect aliases this value could potentially - /// have. Note that the resulting set will also contain the value provided as - /// it is an alias of itself. - ValueSetT resolve(Value value) const { - ValueSetT result; - resolveRecursive(value, result); - return result; - } - - private: - /// Recursively determines alias information for the given value. It stores - /// all newly found potential aliases in the given result set. - void resolveRecursive(Value value, ValueSetT& result) const { - if (!result.insert(value).second) { - return; - } - auto it = aliases.find(value); - if (it == aliases.end()) return; - for (auto alias : it->second) { - resolveRecursive(alias, result); - } - } - - /// This function constructs a mapping from values to its immediate aliases. - /// It iterates over all blocks, gets their predecessors, determines the - /// values that will be passed to the corresponding block arguments and - /// inserts them into map. 
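As a rough illustration of the aliasing rule described above, here is a minimal standalone C++ sketch. It is not the MLIR API and the value names are hypothetical strings; it only models how the ^bb0/^bb1/^bb2/^exit example yields an immediate-alias map (branch operands alias the block arguments they feed) and how transitive aliases of %new_value are resolved from it.

#include <iostream>
#include <map>
#include <set>
#include <string>

using ValueSet = std::set<std::string>;

// Maps a value to its immediate aliases (the block arguments it is passed to).
std::map<std::string, ValueSet> aliases;

// Recursively collects all transitive aliases of `value`, including itself.
void resolveRecursive(const std::string& value, ValueSet& result) {
  if (!result.insert(value).second) return;  // already visited
  auto it = aliases.find(value);
  if (it == aliases.end()) return;
  for (const auto& alias : it->second) resolveRecursive(alias, result);
}

int main() {
  // Walk the branches of the example:
  //   ^bb1: br ^exit(%arg0)       -> %arg0 aliases ^exit's %arg1
  //   ^bb2: br ^exit(%new_value)  -> %new_value aliases ^exit's %arg1
  aliases["%arg0"].insert("%arg1");
  aliases["%new_value"].insert("%arg1");

  // Resolving %new_value yields {%new_value, %arg1}, which is why its dealloc
  // must come after the last use of %arg1 in ^exit.
  ValueSet result;
  resolveRecursive("%new_value", result);
  for (const auto& v : result) std::cout << v << "\n";
  return 0;
}

In the real pass the same resolution runs over mlir::Value instances gathered from the branch op interface; the toy map above only fixes the idea that aliasing flows from predecessor operands into block arguments.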
- void build(MutableArrayRef regions) { - for (Region& region : regions) { - for (Block& block : region) { - // Iterate over all predecessor and get the mapped values to their - // corresponding block arguments values. - for (auto pred : block.getPredecessors()) { - // Determine the current successor index of the current predecessor. - unsigned successorIndex = std::distance( - pred->getSuccessors().begin(), - llvm::find_if(pred->getSuccessors(), [&](Block* successor) { - return successor == █ - })); - // Get the terminator and the values that will be passed to our block. - if (auto branchInterface = - dyn_cast(pred->getTerminator())) { - // Query the branch op interace to get the successor operands. - auto successorOps = - branchInterface.getSuccessorOperands(successorIndex); - if (successorOps.hasValue()) { - // Build the actual mapping of values to their immediate aliases. - for (auto arg : block.getArguments()) { - Value predecessorArgValue = - successorOps.getValue()[arg.getArgNumber()]; - aliases[predecessorArgValue].insert(arg); - } - } - } - } - } - } - } - - /// Maps values to all immediate aliases this value can have. - llvm::DenseMap aliases; -}; - -//===----------------------------------------------------------------------===// -// BufferAssignmentPositions -//===----------------------------------------------------------------------===// - -/// Stores proper alloc and dealloc positions to place dialect-specific alloc -/// and dealloc operations. -struct BufferAssignmentPositions { - public: - BufferAssignmentPositions() - : allocPosition(nullptr), deallocPosition(nullptr) {} - - /// Creates a new positions tuple including alloc and dealloc positions. - BufferAssignmentPositions(Operation* allocPosition, - Operation* deallocPosition) - : allocPosition(allocPosition), deallocPosition(deallocPosition) {} - - /// Returns the alloc position before which the alloc operation has to be - /// inserted. - Operation* getAllocPosition() const { return allocPosition; } - - /// Returns the dealloc position after which the dealloc operation has to be - /// inserted. - Operation* getDeallocPosition() const { return deallocPosition; } - - private: - Operation* allocPosition; - Operation* deallocPosition; -}; - -//===----------------------------------------------------------------------===// -// BufferAssignmentAnalysis -//===----------------------------------------------------------------------===// - -// The main buffer assignment analysis used to place allocs and deallocs. -class BufferAssignmentAnalysis { - public: - using DeallocSetT = SmallPtrSet; - - public: - BufferAssignmentAnalysis(Operation* op) - : operation(op), - liveness(op), - dominators(op), - postDominators(op), - aliases(op) {} - - /// Computes the actual positions to place allocs and deallocs for the given - /// value. - BufferAssignmentPositions computeAllocAndDeallocPositions(Value value) const { - if (value.use_empty()) { - return BufferAssignmentPositions(value.getDefiningOp(), - value.getDefiningOp()); - } - // Get all possible aliases - auto possibleValues = aliases.resolve(value); - return BufferAssignmentPositions(getAllocPosition(value, possibleValues), - getDeallocPosition(value, possibleValues)); - } - - /// Finds all associated dealloc nodes for the alloc nodes using alias - /// information. 
- DeallocSetT findAssociatedDeallocs(AllocOp alloc) const { - DeallocSetT result; - auto possibleValues = aliases.resolve(alloc); - for (auto alias : possibleValues) { - for (auto user : alias.getUsers()) { - if (isa(user)) result.insert(user); - } - } - return result; - } - - /// Dumps the buffer assignment information to the given stream. - void print(raw_ostream& os) const { - os << "// ---- Buffer Assignment -----\n"; - - for (Region& region : operation->getRegions()) - for (Block& block : region) - for (Operation& operation : block) - for (Value result : operation.getResults()) { - BufferAssignmentPositions positions = - computeAllocAndDeallocPositions(result); - os << "Positions for "; - result.print(os); - os << "\n Alloc: "; - positions.getAllocPosition()->print(os); - os << "\n Dealloc: "; - positions.getDeallocPosition()->print(os); - os << "\n"; - } - } - - private: - /// Finds a proper placement block to store alloc/dealloc node according to - /// the algorithm described at the top of the file. It supports dominator and - /// post-dominator analyses via template arguments. - template - Block* findPlacementBlock(Value value, const AliasesT& aliases, - const DominatorT& doms) const { - assert(!value.isa() && "Cannot place a block argument"); - // Start with the current block the value is defined in. - Block* dom = value.getDefiningOp()->getBlock(); - // Iterate over all aliases and their uses to find a safe placement block - // according to the given dominator information. - for (auto alias : aliases) { - for (auto user : alias.getUsers()) { - // Move upwards in the dominator tree to find an appropriate - // dominator block that takes the current use into account. - dom = doms.findNearestCommonDominator(dom, user->getBlock()); - } - } - return dom; - } - - /// Finds a proper alloc positions according to the algorithm described at the - /// top of the file. - template - Operation* getAllocPosition(Value value, const AliasesT& aliases) const { - // Determine the actual block to place the alloc and get liveness - // information. - auto placementBlock = findPlacementBlock(value, aliases, dominators); - auto livenessInfo = liveness.getLiveness(placementBlock); - - // We have to ensure that the alloc will be before the first use of all - // aliases of the given value. We first assume that there are no uses in the - // placementBlock and that we can safely place the alloc before the - // terminator at the end of the block. - Operation* startOperation = placementBlock->getTerminator(); - // Iterate over all aliases and ensure that the startOperation will point to - // the first operation of all potential aliases in the placementBlock. - for (auto alias : aliases) { - auto aliasStartOperation = livenessInfo->getStartOperation(alias); - // Check whether the aliasStartOperation lies in the desired block and - // whether it is before the current startOperation. If yes, this will be - // the new startOperation. - if (aliasStartOperation->getBlock() == placementBlock && - aliasStartOperation->isBeforeInBlock(startOperation)) { - startOperation = aliasStartOperation; - } - } - // startOperation is the first operation before which we can safely store - // the alloc taking all potential aliases into account. - return startOperation; - } - - /// Finds a proper dealloc positions according to the algorithm described at - /// the top of the file. 
- template - Operation* getDeallocPosition(Value value, const AliasesT& aliases) const { - // Determine the actual block to place the dealloc and get liveness - // information. - auto placementBlock = findPlacementBlock(value, aliases, postDominators); - auto livenessInfo = liveness.getLiveness(placementBlock); - - // We have to ensure that the dealloc will be after the last use of all - // aliases of the given value. We first assume that there are no uses in the - // placementBlock and that we can safely place the dealloc at the beginning. - Operation* endOperation = &placementBlock->front(); - // Iterate over all aliases and ensure that the endOperation will point to - // the last operation of all potential aliases in the placementBlock. - for (auto alias : aliases) { - auto aliasEndOperation = - livenessInfo->getEndOperation(alias, endOperation); - // Check whether the aliasEndOperation lies in the desired block and - // whether it is behind the current endOperation. If yes, this will be the - // new endOperation. - if (aliasEndOperation->getBlock() == placementBlock && - endOperation->isBeforeInBlock(aliasEndOperation)) { - endOperation = aliasEndOperation; - } - } - // endOperation is the last operation behind which we can safely store the - // dealloc taking all potential aliases into account. - return endOperation; - } - - /// The operation this transformation was constructed from. - Operation* operation; - - /// The underlying liveness analysis to compute fine grained information about - /// alloc and dealloc positions. - Liveness liveness; - - /// The dominator analysis to place allocs in the appropriate blocks. - DominanceInfo dominators; - - /// The post dominator analysis to place deallocs in the appropriate blocks. - PostDominanceInfo postDominators; - - /// The internal alias analysis to ensure that allocs and deallocs take all - /// their potential aliases into account. - BufferAssignmentAliasAnalysis aliases; -}; - -//===----------------------------------------------------------------------===// -// BufferAssignmentPass -//===----------------------------------------------------------------------===// - -/// The actual buffer assignment pass that moves alloc and dealloc nodes into -/// the right positions. It uses the algorithm described at the top of the file. -// TODO(dfki): create a templated version that allows to match dialect-specific -// alloc/dealloc nodes and to insert dialect-specific dealloc node. -struct BufferAssignmentPass - : mlir::PassWrapper { - void runOnFunction() override { - // Get required analysis information first. - auto& analysis = getAnalysis(); - - // Compute an initial placement of all nodes. - llvm::SmallDenseMap placements; - getFunction().walk([&](AllocOp alloc) { - placements[alloc] = analysis.computeAllocAndDeallocPositions(alloc); - }); - - // Move alloc (and dealloc - if any) nodes into the right places - // and insert dealloc nodes if necessary. - getFunction().walk([&](AllocOp alloc) { - // Find already associated dealloc nodes. - auto deallocs = analysis.findAssociatedDeallocs(alloc); - assert(deallocs.size() < 2 && - "Not supported number of associated dealloc operations"); - - // Move alloc node to the right place. - BufferAssignmentPositions& positions = placements[alloc]; - Operation* allocOperation = alloc.getOperation(); - allocOperation->moveBefore(positions.getAllocPosition()); - - // If there is an existing dealloc, move it to the right place. 
- if (deallocs.size()) { - Operation* nextOp = positions.getDeallocPosition()->getNextNode(); - assert(nextOp && "Invalid Dealloc operation position"); - (*deallocs.begin())->moveBefore(nextOp); - } else { - // If there is no dealloc node, insert one in the right place. - OpBuilder builder(alloc); - builder.setInsertionPointAfter(positions.getDeallocPosition()); - builder.create(allocOperation->getLoc(), alloc); - } - }); - }; -}; - -} // namespace - -//===----------------------------------------------------------------------===// -// BufferAssignmentPlacer -//===----------------------------------------------------------------------===// - -/// Creates a new assignment placer. -BufferAssignmentPlacer::BufferAssignmentPlacer(Operation* op) - : operation(op), dominators(op) {} - -/// Computes the actual position to place allocs for the given value. -OpBuilder::InsertPoint BufferAssignmentPlacer::computeAllocPosition( - Value value) { - Operation* insertOp = value.getDefiningOp(); - assert(insertOp && "There is not a defining operation for the input value"); - OpBuilder opBuilder(insertOp); - return opBuilder.saveInsertionPoint(); -} - -//===----------------------------------------------------------------------===// -// FunctionAndBlockSignatureConverter -//===----------------------------------------------------------------------===// - -// Performs the actual signature rewriting step. -LogicalResult FunctionAndBlockSignatureConverter::matchAndRewrite( - FuncOp funcOp, ArrayRef operands, - ConversionPatternRewriter& rewriter) const { - auto toMemrefConverter = [&](Type t) -> Type { - if (auto tensorType = t.dyn_cast()) { - return MemRefType::get(tensorType.getShape(), - tensorType.getElementType()); - } - return t; - }; - // Converting tensor-type function arguments to memref-type. - auto funcType = funcOp.getType(); - TypeConverter::SignatureConversion conversion(funcType.getNumInputs()); - for (auto argType : llvm::enumerate(funcType.getInputs())) { - conversion.addInputs(argType.index(), toMemrefConverter(argType.value())); - } - for (auto resType : funcType.getResults()) { - conversion.addInputs(toMemrefConverter(resType)); - } - rewriter.updateRootInPlace(funcOp, [&] { - funcOp.setType( - rewriter.getFunctionType(conversion.getConvertedTypes(), llvm::None)); - rewriter.applySignatureConversion(&funcOp.getBody(), conversion); - }); - // Converting tensor-type block arugments of all blocks inside the - // function region to memref-type except for the entry block. - for (auto& block : funcOp.getBlocks()) { - if (block.isEntryBlock()) continue; - for (int i = 0, e = block.getNumArguments(); i < e; ++i) { - auto oldArg = block.getArgument(i); - auto newArg = - block.insertArgument(i, toMemrefConverter(oldArg.getType())); - oldArg.replaceAllUsesWith(newArg); - block.eraseArgument(i + 1); - } - } - return success(); -} - -/// A helper method to make the functions, whose all block argument types are -/// Memref or non-shaped type, legal. BufferAssignmentPlacer expects all -/// function and block argument types are in Memref or non-shaped type. Using -/// this helper method and additionally, FunctionAndBlockSignatureConverter as a -/// pattern conversion make sure that the type of block arguments are compatible -/// with using BufferAssignmentPlacer. 
-void FunctionAndBlockSignatureConverter::addDynamicallyLegalFuncOp( - ConversionTarget& target) { - auto isLegalBlockArg = [](BlockArgument arg) -> bool { - auto type = arg.getType(); - return type.isa() || !type.isa(); - }; - target.addDynamicallyLegalOp([&](FuncOp funcOp) { - bool legality = true; - for (auto& block2 : funcOp.getBlocks()) { - legality &= llvm::all_of(block2.getArguments(), isLegalBlockArg); - if (!legality) break; - } - return legality; - }); -} - -//===----------------------------------------------------------------------===// -// Buffer assignment pass registrations -//===----------------------------------------------------------------------===// - -std::unique_ptr> createBufferAssignmentPass() { - return absl::make_unique(); -} - -static PassRegistration buffer_assignment_pass( - "buffer-assignment", - "Executes buffer assignment pass to automatically move alloc and dealloc " - "operations into their proper positions"); - -} // namespace xla -} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.h b/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.h deleted file mode 100644 index ced5769b44c..00000000000 --- a/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.h +++ /dev/null @@ -1,140 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_BUFFER_ASSIGNMENT_H_ -#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_BUFFER_ASSIGNMENT_H_ - -#include "mlir/Analysis/Liveness.h" -#include "mlir/IR/Builders.h" // TF:llvm-project -#include "mlir/IR/Dominance.h" -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/Support/LLVM.h" -#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project - -namespace mlir { -namespace xla { - -/// Prepares a buffer assignment phase. It can place (user-defined) alloc -/// nodes. This simplifies the integration of the actual buffer-assignment -/// pass. Sample usage: -/// BufferAssignmentPlacer baHelper(regionOp); -/// -> determine alloc positions -/// auto allocPosition = baHelper.computeAllocPosition(value); -/// -> place alloc -/// allocBuilder.setInsertionPoint(positions.getAllocPosition()); -/// -/// alternatively: -/// -> place alloc -/// baHelper.insertAlloc(...); -/// Note: this class is intended to be used during legalization. In order -/// to move alloc and dealloc nodes into the right places you can use the -/// createBufferAssignmentPass() function. -class BufferAssignmentPlacer { - public: - /// Creates a new assignment builder. - explicit BufferAssignmentPlacer(Operation* op); - - /// Returns the operation this analysis was constructed from. - Operation* getOperation() const { return operation; } - - /// Computes the actual position to place allocs for the given value. 
- OpBuilder::InsertPoint computeAllocPosition(Value value); - - private: - /// The operation this analysis was constructed from. - Operation* operation; - - /// The dominator analysis to place allocs in the appropriate blocks. - DominanceInfo dominators; -}; - -/// Helper conversion pattern that encapsulates a BufferAssignmentPlacer -/// instance. -template -class BufferAssignmentOpConversionPattern - : public OpConversionPattern { - public: - explicit BufferAssignmentOpConversionPattern( - MLIRContext* context_, - xla::BufferAssignmentPlacer* bufferAssignment_ = nullptr, - PatternBenefit benefit_ = 1) - : OpConversionPattern(context_, benefit_), - bufferAssignment(bufferAssignment_) {} - - protected: - xla::BufferAssignmentPlacer* bufferAssignment; -}; - -// Converts only the tensor-type function and block arguments to memref-type. -class FunctionAndBlockSignatureConverter - : public BufferAssignmentOpConversionPattern { - public: - using BufferAssignmentOpConversionPattern< - FuncOp>::BufferAssignmentOpConversionPattern; - - // Adding functions whose arguments are memref type to the set of legal - // operations. - static void addDynamicallyLegalFuncOp(ConversionTarget& target); - - // Performs the actual signature rewriting step. - LogicalResult matchAndRewrite( - FuncOp funcOp, ArrayRef operands, - ConversionPatternRewriter& rewriter) const final; -}; - -// This pattern converter transforms a non-void ReturnOpSourceTy into a void -// return of type ReturnOpTargetTy. It uses a copy operation of type CopyOpTy to -// copy the results to the output buffer. -template -class NonVoidToVoidReturnOpConverter - : public BufferAssignmentOpConversionPattern { - public: - using BufferAssignmentOpConversionPattern< - ReturnOpSourceTy>::BufferAssignmentOpConversionPattern; - - // Performs the actual return-op conversion step. - LogicalResult matchAndRewrite( - ReturnOpSourceTy returnOp, ArrayRef operands, - ConversionPatternRewriter& rewriter) const final { - auto numReturnValues = returnOp.getNumOperands(); - auto funcOp = returnOp.template getParentOfType(); - auto numFuncArgs = funcOp.getNumArguments(); - auto loc = returnOp.getLoc(); - - // Find the corresponding output buffer for each operand. - for (auto operand : llvm::enumerate(operands)) { - auto returnArgNumber = numFuncArgs - numReturnValues + operand.index(); - auto dstBuffer = funcOp.getArgument(returnArgNumber); - if (dstBuffer == operand.value()) { - continue; - } - - // Insert the copy operation to copy before the return. - rewriter.setInsertionPoint( - returnOp.getOperation()->getBlock()->getTerminator()); - rewriter.create(loc, operand.value(), - funcOp.getArgument(returnArgNumber)); - } - // Insert the new target return operation. - rewriter.replaceOpWithNewOp(returnOp); - return success(); - } -}; - -} // namespace xla -} // namespace mlir - -#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_BUFFER_ASSIGNMENT_H_ diff --git a/tensorflow/compiler/mlir/xla/transforms/buffer_assignment_test.cc b/tensorflow/compiler/mlir/xla/transforms/buffer_assignment_test.cc deleted file mode 100644 index 5a0d791079c..00000000000 --- a/tensorflow/compiler/mlir/xla/transforms/buffer_assignment_test.cc +++ /dev/null @@ -1,170 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file implements logic for testing buffer assignment including its -// utility converters. - -#include "tensorflow/compiler/mlir/xla/transforms/buffer_assignment.h" - -#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project -#include "mlir/IR/Function.h" // TF:llvm-project -#include "mlir/IR/Operation.h" // TF:llvm-project -#include "mlir/Pass/Pass.h" // TF:llvm-project -#include "mlir/Pass/PassManager.h" // TF:llvm-project -#include "absl/memory/memory.h" -#include "tensorflow/compiler/mlir/xla/transforms/passes.h" - -namespace mlir { -namespace xla { -namespace { -/// This pass tests two provided operation converters, -/// FunctionAndBlockSignatureConverter and NonVoidToVoidReturnOpConverter, for -/// Buffer Assignment. -struct BufferAssignmentPreparationTestPass - : mlir::PassWrapper { - /// This dialect independent unary operation has been defined only for testing - /// buffer assignment. - class BufferAssignmentTestUnaryOp - : public Op { - public: - using Op::Op; - static StringRef getOperationName() { - return "buffer_assignment_test.unary"; - } - static void build(OpBuilder& b, OperationState& state, Value source) { - state.addOperands(source); - } - }; - - /// This dialect independent lowered unary operation has been defined only for - /// testing buffer assignment. - class BufferAssignmentTestUnaryLoweredOp - : public Op::Impl> { - public: - using Op::Op; - static StringRef getOperationName() { - return "buffer_assignment_test.unary_lowered"; - } - static void build(OpBuilder& b, OperationState& state, Value source, - Value target) { - state.addOperands(source); - state.addOperands(target); - } - }; - - /// This dialect independent copy operation has been defined only for testing - /// NonVoidToVoidReturnOpConverter - class BufferAssignmentTestCopyOp - : public Op::Impl> { - public: - using Op::Op; - static StringRef getOperationName() { - return "buffer_assignment_test.copy"; - } - static void build(OpBuilder& b, OperationState& state, Value from, - Value to) { - state.addOperands(from); - state.addOperands(to); - } - }; - - /// A simple converter that legalizes a BufferAssignmentTestUnaryOp to a - /// BufferAssignmentTestUnaryLoweredOp and creates buffer allocation for - /// the result of the computation. - class TestUnaryOpConverter : public BufferAssignmentOpConversionPattern< - BufferAssignmentTestUnaryOp> { - public: - using BufferAssignmentOpConversionPattern< - BufferAssignmentTestUnaryOp>::BufferAssignmentOpConversionPattern; - - // Performs the actual legalization conversion step. - LogicalResult matchAndRewrite( - BufferAssignmentTestUnaryOp op, ArrayRef operands, - ConversionPatternRewriter& rewriter) const final { - // Create a new buffer allocation using the current BufferAssignmentPlacer - // instance. 
- auto result = op.getResult(); - auto result_type = result.getType().dyn_cast(); - auto memref_type = - MemRefType::get(result_type.getShape(), result_type.getElementType()); - rewriter.restoreInsertionPoint( - bufferAssignment->computeAllocPosition(result)); - auto alloc = rewriter.create(op.getLoc(), memref_type); - - // Create the lowered operation and replace the old operation with a - // reference to the allocated buffer. - rewriter.create(op.getLoc(), - operands[0], alloc); - rewriter.replaceOp(op, {alloc}); - return success(); - } - }; - - void runOnFunction() override { - OwningRewritePatternList patterns; - auto funcOp = getOperation(); - auto context = funcOp.getContext(); - ConversionTarget target(*context); - BufferAssignmentPlacer bufferAssignmentPlacer(funcOp); - - // Specifying the legal and illegal operations. - context->allowUnregisteredDialects(true); - target.addIllegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - // TODO(dfki): ReturnOp can also be changed to TestReturnOp like - // BufferAssignmentTestCopyOp. - target.addDynamicallyLegalOp( - [](ReturnOp returnOp) { return returnOp.getNumOperands() == 0; }); - FunctionAndBlockSignatureConverter::addDynamicallyLegalFuncOp(target); - - // Adding patterns for testing this pass. - // clang-format off - patterns.insert< - FunctionAndBlockSignatureConverter, - TestUnaryOpConverter, - NonVoidToVoidReturnOpConverter - - >(context, &bufferAssignmentPlacer); - // clang-format on - - if (failed(applyPartialConversion(funcOp, target, patterns, nullptr))) { - funcOp.emitOpError() - << "Failed to apply buffer assignment preparation steps"; - } - }; -}; -} // namespace - -/// This pass tests helper methods such as computeAllocPosition, -/// FunctionAndBlockSignatureConverter, NonVoidToVoidReturnOpConverter -/// conversion patterns. Furthermore, it checks buffer-assignment pass that -/// moves existing Alloc and Dealloc operations to their proper positions, and -/// insert missing Dealloc operations. -static PassPipelineRegistration<> buffer_assignment_test_pass( - "test-buffer-assignment", - "Tests buffer assignment helper methods and buffer assignment pass.", - [](mlir::OpPassManager& pm) { - pm.addPass(absl::make_unique()); - pm.addPass(createBufferAssignmentPass()); - }); - -} // namespace xla -} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/chlo_legalize_to_hlo.cc b/tensorflow/compiler/mlir/xla/transforms/chlo_legalize_to_hlo.cc index a20511a95fc..e5a79616d5b 100644 --- a/tensorflow/compiler/mlir/xla/transforms/chlo_legalize_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/transforms/chlo_legalize_to_hlo.cc @@ -33,24 +33,23 @@ namespace { // Converts binary ops that statically are determined to not broadcast directly // to the corresponding xla_hlo non-broadcasting op. template -struct ConvertTrivialNonBroadcastBinaryOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - ChloOpTy op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { +struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ChloOpTy op, + PatternRewriter &rewriter) const override { // Only rewrite for statically determinable non-broadcasting cases. 
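In shape terms, the guard below requires both operands to be ranked, of equal rank, fully static, and identical in every extent, so broadcasting is provably a no-op. A standalone restatement using only the standard library; the function name and the use of -1 for a dynamic extent are assumptions of this sketch:

#include <cstdint>
#include <vector>

constexpr int64_t kDynamicDim = -1;  // stand-in for a dynamic extent

// True iff the two shapes are statically known to be identical, so an
// elementwise op on them needs no broadcast at all.
static bool TriviallyNonBroadcasting(const std::vector<int64_t>& lhs,
                                     const std::vector<int64_t>& rhs) {
  if (lhs.size() != rhs.size()) return false;  // rank mismatch
  for (size_t i = 0; i < lhs.size(); ++i) {
    if (lhs[i] == kDynamicDim || rhs[i] == kDynamicDim) return false;
    if (lhs[i] != rhs[i]) return false;  // differing static extent
  }
  return true;
}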
- auto lhs = operands[0].getType().dyn_cast(); - auto rhs = operands[1].getType().dyn_cast(); - if (!lhs || !rhs) return failure(); + auto lhs_type = op.lhs().getType().template dyn_cast(); + auto rhs_type = op.rhs().getType().template dyn_cast(); + if (!lhs_type || !rhs_type) return failure(); // Requires rank broadcast. - if (lhs.getRank() != rhs.getRank()) return failure(); + if (lhs_type.getRank() != rhs_type.getRank()) return failure(); // Any dynamic dimension may require broadcasting and requires more // analysis. - if (!lhs.hasStaticShape() || !rhs.hasStaticShape()) return failure(); + if (!lhs_type.hasStaticShape() || !rhs_type.hasStaticShape()) + return failure(); - for (auto extents : llvm::zip(lhs.getShape(), rhs.getShape())) { + for (auto extents : llvm::zip(lhs_type.getShape(), rhs_type.getShape())) { auto lhs_extent = std::get<0>(extents); auto rhs_extent = std::get<1>(extents); if (lhs_extent != rhs_extent) { @@ -58,9 +57,8 @@ struct ConvertTrivialNonBroadcastBinaryOp } } - rewriter.replaceOp( - op, {Adaptor::CreateOp(op, op.getResult().getType(), operands[0], - operands[1], rewriter)}); + rewriter.replaceOp(op, {Adaptor::CreateOp(op, op.getResult().getType(), + op.lhs(), op.rhs(), rewriter)}); return success(); } }; @@ -83,14 +81,13 @@ struct ConvertTrivialNonBroadcastBinaryOp // Whether that is of any practical benefit remains to be seen. template struct ConvertRankedDynamicBroadcastBinaryOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - ChloOpTy op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ChloOpTy op, + PatternRewriter &rewriter) const override { // Only support ranked operands. - Value lhs = operands[0]; - Value rhs = operands[1]; + Value lhs = op.lhs(); + Value rhs = op.rhs(); auto lhs_type = lhs.getType().dyn_cast(); auto rhs_type = rhs.getType().dyn_cast(); auto result_type = @@ -166,8 +163,7 @@ struct HloBinaryElementwiseAdaptor { Value broadcasted_lhs, Value broadcasted_rhs, OpBuilder &builder) { return builder.create(from_op.getLoc(), result_type, - broadcasted_lhs, broadcasted_rhs, - /*broadcast_dimensions=*/nullptr); + broadcasted_lhs, broadcasted_rhs); } }; @@ -186,9 +182,9 @@ struct HloCompareAdaptor { Type result_type, Value broadcasted_lhs, Value broadcasted_rhs, OpBuilder &builder) { - return builder.create( - from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs, - /*broadcast_dimensions=*/nullptr, from_op.comparison_direction()); + return builder.create(from_op.getLoc(), result_type, + broadcasted_lhs, broadcasted_rhs, + from_op.comparison_direction()); } }; diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc index aa29241048b..df92681cd97 100644 --- a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Transforms/BufferPlacement.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" @@ -39,16 +40,11 @@ namespace xla_hlo { namespace { constexpr StringRef kTempBufferAttr = "temp"; - -/// Returns DeallocOp to ensure that CopyOp is not inserted after dealloc. -Operation* FindInsertionPointForCopy(Value value) { - for (const auto& user : value.getUsers()) { - if (auto dealloc = dyn_cast(user)) { - return user; - } - } - return nullptr; -} +template +using BaseOpConversion = BufferAssignmentOpConversionPattern; +using StdReturnOpConverter = + BufferAssignmentReturnOpConverter; Value InsertDynamicAllocAndDealloc(Location loc, Value result, Value shape_operand, @@ -92,8 +88,9 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result, return alloc; } -Value InsertAllocAndDealloc(Location loc, Value result, - ConversionPatternRewriter* rewriter) { +Value InsertAlloc(Location loc, OpResult result, + BufferAssignmentPlacer* bufferAssignment, + ConversionPatternRewriter* rewriter) { auto result_type = result.getType().dyn_cast(); if (!result_type || !result_type.hasStaticShape()) { result.getDefiningOp()->emitOpError() @@ -101,31 +98,21 @@ Value InsertAllocAndDealloc(Location loc, Value result, } auto memref_type = MemRefType::get(result_type.getShape(), result_type.getElementType()); - - Operation* op = result.getDefiningOp(); - auto block = op->getBlock(); - - OpBuilder allocBuilder(op); - allocBuilder.setInsertionPointToStart(block); // Inserting at the beginning - auto alloc = allocBuilder.create(loc, memref_type); - - alloc.setAttr(kTempBufferAttr, rewriter->getBoolAttr(true)); - - allocBuilder.setInsertionPoint(block, std::prev(block->end())); - allocBuilder.create(loc, alloc); - + OpBuilder::InsertionGuard guard(*rewriter); + rewriter->restoreInsertionPoint( + bufferAssignment->computeAllocPosition(result)); + auto alloc = rewriter->create(loc, memref_type); return alloc; } template -class HloToLhloOpConverter : public ConversionPattern { +class HloToLhloOpConverter : public BaseOpConversion { public: - explicit HloToLhloOpConverter(MLIRContext* context) - : ConversionPattern(HloOpTy::getOperationName(), 1, context) {} - + using BaseOpConversion::BaseOpConversion; LogicalResult matchAndRewrite( - Operation* op, ArrayRef operands, + HloOpTy hloOp, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { + Operation* op = hloOp.getOperation(); const auto& original_results = op->getResults(); SmallVector buffer_args(operands.begin(), operands.end()); for (auto result : llvm::enumerate(original_results)) { @@ -135,8 +122,8 @@ class HloToLhloOpConverter : public ConversionPattern { return failure(); } if (resultType.hasStaticShape()) { - buffer_args.push_back( - InsertAllocAndDealloc(op->getLoc(), result.value(), &rewriter)); + buffer_args.push_back(InsertAlloc(op->getLoc(), result.value(), + this->bufferAssignment, &rewriter)); } else { SmallVector results_shape; auto shape_type_op = dyn_cast(op); @@ -156,9 +143,9 @@ class HloToLhloOpConverter : public ConversionPattern { }; struct HloToLhloDynamicBroadcastInDimOpConverter - : public OpConversionPattern { + : public BaseOpConversion { public: - using OpConversionPattern::OpConversionPattern; + using 
BaseOpConversion::BaseOpConversion; LogicalResult matchAndRewrite( xla_hlo::DynamicBroadcastInDimOp op, ArrayRef operands, @@ -175,10 +162,9 @@ struct HloToLhloDynamicBroadcastInDimOpConverter } }; -struct HloToLhloReduceOpConverter - : public OpConversionPattern { +struct HloToLhloReduceOpConverter : public BaseOpConversion { public: - using OpConversionPattern::OpConversionPattern; + using BaseOpConversion::BaseOpConversion; LogicalResult matchAndRewrite( xla_hlo::ReduceOp op, ArrayRef operands, @@ -194,7 +180,8 @@ struct HloToLhloReduceOpConverter const auto& original_results = op.getResults(); SmallVector buffer_args(operands.begin(), operands.end()); for (auto result : original_results) { - buffer_args.push_back(InsertAllocAndDealloc(loc, result, &rewriter)); + buffer_args.push_back( + InsertAlloc(loc, result, this->bufferAssignment, &rewriter)); } auto new_op = rewriter.create( loc, llvm::None, buffer_args, op.getAttrs()); @@ -230,12 +217,12 @@ struct HloToLhloReduceOpConverter } }; -class HloToLhloTensorLoadOpConverter : public ConversionPattern { +class HloToLhloTensorLoadOpConverter + : public BaseOpConversion { public: - explicit HloToLhloTensorLoadOpConverter(MLIRContext* context) - : ConversionPattern(TensorLoadOp::getOperationName(), 1, context) {} + using BaseOpConversion::BaseOpConversion; LogicalResult matchAndRewrite( - Operation* op, ArrayRef operands, + mlir::TensorLoadOp op, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { rewriter.replaceOp(op, operands); return success(); @@ -243,13 +230,13 @@ class HloToLhloTensorLoadOpConverter : public ConversionPattern { }; // TODO(b/137624192): Rewrite into a copy and elide copy if possible. -class HloToLhloTensorStoreOpConverter : public ConversionPattern { +class HloToLhloTensorStoreOpConverter + : public BaseOpConversion { public: - explicit HloToLhloTensorStoreOpConverter(MLIRContext* context) - : ConversionPattern(TensorStoreOp::getOperationName(), 1, context) {} + using BaseOpConversion::BaseOpConversion; LogicalResult matchAndRewrite( - Operation* op, ArrayRef operands, + mlir::TensorStoreOp op, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { rewriter.replaceOpWithNewOp( op, llvm::None, operands.front(), operands.back()); @@ -291,7 +278,6 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern { // (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () // "xla_lhlo.multiply"(%0, %arg0, %arg3) : // (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () -// dealloc %0 : memref<2x2xf32> // "xla_lhlo.terminator"() : () -> () // }) : () -> () // return @@ -313,14 +299,13 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern { // %arg1: memref<4xf32>, // %arg2: memref<4xf32>) { // %0 = alloc() : memref<4xf32> -// %1 = alloc() : memref<4xf32> + // "xla_lhlo.maximum"(%arg0, %arg1, %0) : // (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> () +// %1 = alloc() : memref<4xf32> // "xla_lhlo.add"(%arg0, %0, %1) : // (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> () // "xla_lhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> () -// dealloc %0 : memref<4xf32> -// dealloc %1 : memref<4xf32> // "xla_lhlo.terminator"() : () -> () // } @@ -337,7 +322,7 @@ struct HloLegalizeToLhlo target.addIllegalOp(); target.addIllegalOp(); target.addLegalOp(); - target.addLegalOp(); + target.addLegalOp(); target.addIllegalDialect(); target.addDynamicallyLegalOp([&](FuncOp op) { auto inputs = op.getType().getInputs(); @@ -346,101 +331,25 @@ struct HloLegalizeToLhlo }); auto 
module = getOperation(); - populateHLOToLHLOConversionPattern(module.getContext(), &patterns); - - // Do partial conversion so we can have unknown ops in tests. - if (failed(applyPartialConversion(module, target, patterns, nullptr))) { - signalPassFailure(); - } + BufferAssignmentTypeConverter converter; + module.walk([&](FuncOp func) { + BufferAssignmentPlacer bufferAssignment(func); + OwningRewritePatternList patterns; + populateHLOToLHLOConversionPattern(func.getContext(), &bufferAssignment, + &converter, &patterns); + return WalkResult( + applyPartialConversion(func, target, patterns, &converter)); + }); } }; - -Type ConvertType(Type t) { - if (auto tensorType = t.dyn_cast()) { - return MemRefType::get(tensorType.getShape(), tensorType.getElementType()); - } - return t; -} - } // namespace -/// Transforms FuncOp arguments and results from tensors to buffers. Tensor -/// results are converted to memrefs and appended to the argument list. -class HloToLhloFuncOpConverter : public OpConversionPattern { - public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - FuncOp funcOp, ArrayRef operands, - ConversionPatternRewriter& rewriter) const final { - if (funcOp.getBody().getBlocks().size() > 1) { - funcOp.emitOpError() << "tensor to buffer conversion expects a single " - "block in the region containing the operation"; - return failure(); - } - - auto funcType = funcOp.getType(); - - TypeConverter::SignatureConversion conversion(funcType.getNumInputs()); - for (auto argType : llvm::enumerate(funcType.getInputs())) { - conversion.addInputs(argType.index(), ConvertType(argType.value())); - } - for (auto resType : funcType.getResults()) { - conversion.addInputs(ConvertType(resType)); - } - rewriter.updateRootInPlace(funcOp, [&] { - funcOp.setType( - rewriter.getFunctionType(conversion.getConvertedTypes(), llvm::None)); - rewriter.applySignatureConversion(&funcOp.getBody(), conversion); - }); - return success(); - } -}; - -/// Transforms ReturnOp to LhloTerminator. CopyOp is inserted to copy each -/// result to the corresponding buffer argument. 
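The return-op rewriting described here (and its upstream replacement) rely on one convention: once tensor results are appended to the converted signature as output buffers, result i of an N-result function with num_args arguments is copied into argument num_args - N + i. A small sketch of that index mapping; the function name is illustrative:

#include <cassert>

// E.g. a function converted to (%arg0, %arg1, %out0, %out1) with two results:
// result 0 -> argument 2 (%out0), result 1 -> argument 3 (%out1).
static unsigned ResultBufferArgIndex(unsigned num_args, unsigned num_results,
                                     unsigned result_index) {
  assert(num_results <= num_args && result_index < num_results);
  return num_args - num_results + result_index;
}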
-class StdToLhloReturnOpConverter : public OpConversionPattern { - public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - mlir::ReturnOp returnOp, ArrayRef operands, - ConversionPatternRewriter& rewriter) const final { - auto numReturnValues = returnOp.getNumOperands(); - auto funcOp = returnOp.getParentOfType(); - auto numFuncArgs = funcOp.getNumArguments(); - auto loc = returnOp.getLoc(); - - for (auto operand : llvm::enumerate(operands)) { - auto returnArgNumber = numFuncArgs - numReturnValues + operand.index(); - auto dstBuffer = funcOp.getArgument(returnArgNumber); - if (dstBuffer == operand.value()) { - continue; - } - - auto dealloc = FindInsertionPointForCopy(operand.value()); - - if (dealloc == nullptr) { - returnOp.emitOpError() - << "Missing dealloc for operand " << operand.index(); - return failure(); - } - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(dealloc); - rewriter.create(loc, llvm::None, operand.value(), - funcOp.getArgument(returnArgNumber)); - } - rewriter.replaceOpWithNewOp(returnOp); - return success(); - } -}; - -void populateHLOToLHLOConversionPattern(MLIRContext* context, - OwningRewritePatternList* patterns) { +void populateHLOToLHLOConversionPattern( + MLIRContext* context, BufferAssignmentPlacer* bufferAssignment, + TypeConverter* converter, OwningRewritePatternList* patterns) { // clang-format off patterns->insert< HloToLhloDynamicBroadcastInDimOpConverter, - HloToLhloFuncOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, @@ -453,6 +362,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, @@ -472,8 +382,9 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context, HloToLhloReduceOpConverter, HloToLhloTensorLoadOpConverter, HloToLhloTensorStoreOpConverter, - StdToLhloReturnOpConverter - >(context); + FunctionAndBlockSignatureConverter, + StdReturnOpConverter + >(context, bufferAssignment, converter); // clang-format on } diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc index 129a24600a2..bb1169a57d6 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc @@ -61,47 +61,46 @@ LogicalResult ReplaceTerminators(Region* region, Block* target_block, return success(); } -LogicalResult LowerConditionalOp(mlir::xla_hlo::ConditionalOp conditional_op) { - Operation* op_inst = conditional_op.getOperation(); - mlir::OpBuilder builder(conditional_op); +LogicalResult LowerIfOp(mlir::xla_hlo::IfOp if_op) { + Operation* op_inst = if_op.getOperation(); + mlir::OpBuilder builder(if_op); auto orig_block = op_inst->getBlock(); auto* tail_block = orig_block->splitBlock(op_inst); - auto loc = conditional_op.getLoc(); + auto loc = if_op.getLoc(); // Duplicate the true and false regions in the block between the sections // before and after the conditional. 
BlockAndValueMapping mapper; - conditional_op.true_branch().cloneInto(orig_block->getParent(), - Region::iterator(tail_block), mapper); - conditional_op.false_branch().cloneInto(orig_block->getParent(), - Region::iterator(tail_block), mapper); + if_op.true_branch().cloneInto(orig_block->getParent(), + Region::iterator(tail_block), mapper); + if_op.false_branch().cloneInto(orig_block->getParent(), + Region::iterator(tail_block), mapper); // Determine the blocks for the start of the true and false regions. - Block* true_block = mapper.lookup(&conditional_op.true_branch().front()); - Block* false_block = mapper.lookup(&conditional_op.false_branch().front()); + Block* true_block = mapper.lookup(&if_op.true_branch().front()); + Block* false_block = mapper.lookup(&if_op.false_branch().front()); // Perform the conditional branch into the true/false cases. builder.setInsertionPointToEnd(orig_block); // Extract the predicate for checking branching, then branch to the true and // false regions appropriately. - auto cond_value = - builder.create(loc, conditional_op.pred()); + auto cond_value = builder.create(loc, if_op.pred()); builder.create(loc, cond_value, true_block, - conditional_op.true_arg(), false_block, - conditional_op.false_arg()); + if_op.true_arg(), false_block, + if_op.false_arg()); // Replace the true case's return operations with a branch to the tail of // the condition. - if (failed(ReplaceTerminators(&conditional_op.true_branch(), tail_block, loc, - mapper, &builder))) + if (failed(ReplaceTerminators(&if_op.true_branch(), tail_block, loc, mapper, + &builder))) return failure(); - if (failed(ReplaceTerminators(&conditional_op.false_branch(), tail_block, loc, - mapper, &builder))) + if (failed(ReplaceTerminators(&if_op.false_branch(), tail_block, loc, mapper, + &builder))) return failure(); - tail_block->addArguments(conditional_op.getResult().getType()); - conditional_op.getResult().replaceAllUsesWith(tail_block->getArgument(0)); + tail_block->addArguments(if_op.getResult().getType()); + if_op.getResult().replaceAllUsesWith(tail_block->getArgument(0)); op_inst->erase(); return success(); @@ -210,11 +209,11 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) { void LegalizeControlFlow::runOnFunction() { auto func = getFunction(); - llvm::SmallVector conditional_ops; - func.walk([&](ConditionalOp op) { conditional_ops.push_back(op); }); + llvm::SmallVector if_ops; + func.walk([&](IfOp op) { if_ops.push_back(op); }); - for (auto& op : conditional_ops) { - if (failed(LowerConditionalOp(op))) return signalPassFailure(); + for (auto& op : if_ops) { + if (failed(LowerIfOp(op))) return signalPassFailure(); } llvm::SmallVector while_ops; diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index fb03c9b82e5..2d6da67fc60 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -25,6 +25,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/Dialect/Traits.h" // from @llvm-project @@ -43,9 +44,11 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" #include "tensorflow/compiler/mlir/xla/convert_op_folder.h" +#include "tensorflow/compiler/mlir/xla/ir/chlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" +#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/client/sharding_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -64,8 +67,9 @@ class LegalizeTF : public PassWrapper { public: LegalizeTF() = default; LegalizeTF(const LegalizeTF &) {} - explicit LegalizeTF(bool allow_partial_conversion) { + explicit LegalizeTF(bool allow_partial_conversion, bool legalize_chlo) { allow_partial_conversion_ = allow_partial_conversion; + legalize_chlo_ = legalize_chlo; } /// Performs the lowering to XLA dialect. @@ -76,6 +80,11 @@ class LegalizeTF : public PassWrapper { *this, "allow-partial-conversion", llvm::cl::desc("Allow operations that can't be legalized."), llvm::cl::init(false)}; + Option legalize_chlo_{ + *this, "legalize-chlo", + llvm::cl::desc( + "Also legalizes intermediate chlo ops to hlo (default true)"), + llvm::cl::init(true)}; }; /// Returns if the given TF data format string is the default format. @@ -359,6 +368,174 @@ static Value UpdateSliceInMinorDims(Location loc, Value v, Value update, return DynamicUpdateSliceInMinorDims(loc, v, update, dus_starts, builder); } +// Deprecated: This is maintained to aid in porting old code that is not yet +// dynamic shape aware and uses broadcasting modes that CHLO does not support. +// Gets the resulting type from a broadcast between two types for statically +// shaped types. This is to be used for legacy lowerings that both use non +// left-padded broadcasting and static shapes. Its use should not be permitted +// in new code. +// May return nullptr on invalid static broadcast dimensions. +// ABSL_DEPRECATED() +static RankedTensorType GetStaticBroadcastType( + RankedTensorType x, RankedTensorType y, + DenseIntElementsAttr broadcast_dimensions_attr) { + auto element_type = x.getElementType(); + auto shape_x = x.getShape(); + auto shape_y = y.getShape(); + + if (shape_x.size() == shape_y.size()) { + llvm::SmallVector out_shape(shape_x.size()); + for (int i = 0; i < shape_x.size(); i++) { + auto x_val = shape_x[i]; + auto y_val = shape_y[i]; + out_shape[i] = std::max(x_val, y_val); + } + return RankedTensorType::get(out_shape, element_type); + } + + auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y; + auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y; + + llvm::SmallVector broadcast_dimensions; + // Explicit broadcast dimensions. + for (const APInt &int_value : broadcast_dimensions_attr) { + broadcast_dimensions.push_back(int_value.getSExtValue()); + } + if (broadcast_dimensions.size() != shape_small.size()) { + return nullptr; + } + llvm::SmallVector out_shape(shape_large.begin(), + shape_large.end()); + + // Update according to the broadcast dimensions. 
+ for (auto index_pair : llvm::enumerate(broadcast_dimensions)) { + auto old_value = out_shape[index_pair.value()]; + auto new_value = shape_small[index_pair.index()]; + out_shape[index_pair.value()] = std::max(old_value, new_value); + } + return RankedTensorType::get(out_shape, element_type); +} + +// Deprecated: This is maintained to aid in porting old code that is not yet +// dynamic shape aware and uses broadcasting modes that CHLO does not support. +// Applies static binary broadcasting to a binary elementwise op. +// This is a legacy helper to provide general broadcasting support in legacy, +// static shaped code that relies on non-left-padded broadcasting semantics. +template +static Value StaticBinaryBroadcast(Location loc, Value x, Value y, + DenseIntElementsAttr broadcast_dims, + OpBuilder &builder) { + auto x_type = x.getType().cast(); + auto y_type = y.getType().cast(); + auto result_type = GetStaticBroadcastType(x_type, y_type, broadcast_dims); + if (!result_type) { + emitError(loc) << "could not binary broadcast " << x_type << ", " << y_type + << " with broadcast_dims = " << broadcast_dims; + return nullptr; + } + auto larger_broadcast_dims = + GetI64ElementsAttrForSeq(0, result_type.getRank(), &builder); + if (x_type.getRank() < y_type.getRank()) { + if (x_type != result_type) { + x = builder.create(loc, result_type, x, broadcast_dims); + } + if (y_type != result_type) { + y = builder.create(loc, result_type, y, + larger_broadcast_dims); + } + } else { + if (x_type != result_type) { + x = builder.create(loc, result_type, x, + larger_broadcast_dims); + } + if (y_type != result_type) { + y = builder.create(loc, result_type, y, broadcast_dims); + } + } + return builder.create(loc, x, y); +} + +// Gets a 1D tensor type suitable for expressing extents of the given tensor +// value type. If the value type is ranked, the result will be statically +// shaped. Otherwise, it will have a dynamic dimension. +static RankedTensorType GetExtentsTensorTypeFor(TensorType value_type) { + Builder b(value_type.getContext()); + int64_t dim = value_type.hasRank() ? value_type.getRank() : -1; + return RankedTensorType::get({dim}, b.getIndexType()); +} + +// Broadcasts a 'lower_rank_value' to the shape of a 'higher_rank_value' +// by assuming that the shape of the lower ranked is a broadcast compatible +// prefix of the higher ranked. +// Values must be RankedTensorType (this restriction derives from the +// broadcast_dimensions attribute on DynamicBroadcastInDim). +// +// Example: +// CommonPrefixBroadcast(tensor<4x3x256>, tensor<4, 3>) will broadcast the +// lower rank value to [4, 3, 256] (i.e. the opposite of numpy-style +// implicit broadcasting). +static Value CommonPrefixBroadcast(Location loc, Value higher_rank_value, + Value lower_rank_value, OpBuilder &builder) { + Value higher_rank_shape = + builder.create(loc, higher_rank_value); + auto result_extents_type = + GetExtentsTensorTypeFor(higher_rank_value.getType().cast()); + Value result_extents = builder.create( + loc, result_extents_type, higher_rank_shape); + + auto lower_rank_type = lower_rank_value.getType().cast(); + auto lower_rank = lower_rank_type.getRank(); + auto prefix_dims = GetI64ElementsAttrForSeq(0, lower_rank, &builder); + return builder.create( + loc, higher_rank_value.getType(), lower_rank_value, result_extents, + prefix_dims); +} + +// Given a value (broadcast_to) and a feature dimension, broadcasts a 1D +// value (broadcast_from) along that feature dimension. 
This is a shortcut +// for the cases where a 1D tensor must be broadcast along a specific feature +// dimension, which can vary based on data layout, etc. +// +// The extent of `broadcast_from` dim0 must be equal to the extent of the +// feature_dim of `broadcast_to`. +// +// Example: +// [1x2x3x4], [2], 1 -> [1x2x3x4] +// TODO(laurenzo): Swap the order of broadcast_to and broadcast_from for +// consistency. Possibly also rename for clarity. +static Value Broadcast1DToFeatureDim(Location loc, Value broadcast_to, + Value broadcast_from, int64_t feature_dim, + OpBuilder &builder) { + auto broadcast_dims = GetI64ElementsAttr({feature_dim}, &builder); + auto to_type = broadcast_to.getType().cast(); + auto result_shape = builder.create(loc, broadcast_to); + auto result_extents_type = GetExtentsTensorTypeFor(to_type); + auto result_extents = builder.create( + loc, result_extents_type, result_shape); + return builder.create( + loc, to_type, broadcast_from, result_extents, broadcast_dims); +} + +// Broadcasts `input` to the shape of `broadcast_to` value following +// TF::BroadcastTo semantics. +// +// Requires that input is a ranked tensor. +// +// TODO(hinsu): Utilize TF::ShapeOp followed by TF::BroadcastTo once ShapeOp +// supports unranked inputs in the lowering. +static Value BroadcastToShapeOf(Location loc, Value input, Value broadcast_to, + OpBuilder &builder) { + auto result_shape = builder.create(loc, broadcast_to); + auto to_type = broadcast_to.getType().cast(); + auto result_extents_type = GetExtentsTensorTypeFor(to_type); + auto result_extents = builder.create( + loc, result_extents_type, result_shape); + int64_t rank = input.getType().cast().getRank(); + auto broadcast_dims = GetI64ElementsAttrForSeq(0, rank, &builder); + return builder.create( + loc, to_type, input, result_extents, broadcast_dims); +} + // Creates a batch dot using xla_hlo::DotGeneralOp. Value BatchDot(Location loc, Value lhs, bool transpose_lhs, Value rhs, bool transpose_rhs, int64_t num_batch_dims, @@ -404,8 +581,7 @@ static void BuildReduceBody(Type element_type, Region *body, Location loc = body->getLoc(); auto reducer = - builder->create(loc, block->getArgument(0), block->getArgument(1), - /*broadcast_dimensions=*/nullptr); + builder->create(loc, block->getArgument(0), block->getArgument(1)); builder->create(loc, reducer.getResult()); } @@ -505,8 +681,7 @@ static void CreateWhile32(Location loc, int num_iterations, loc, builder->getI32IntegerAttr(num_iterations)); StringAttr compare_direction = StringAttr::get("LT", builder->getContext()); Value compare = builder->create( - loc, loop_iv, upper_limit, - /*broadcast_dimensions=*/nullptr, compare_direction); + loc, loop_iv, upper_limit, compare_direction); builder->create(loc, compare); } @@ -536,9 +711,9 @@ static void CreateWhile32(Location loc, int num_iterations, // Increment the loop induction variable by one. auto one = builder->create(loc, builder->getI32IntegerAttr(1)); - auto no_broadcast_dims = GetI64ElementsAttr({}, builder); - auto plus_one = builder->create(loc, old_values[0], one, - no_broadcast_dims); + auto scalar_broadcast_dims = GetI64ElementsAttr({}, builder); + auto plus_one = builder->create( + loc, old_values[0], one, scalar_broadcast_dims); // Prepend with the updated loop induction variable. 
new_values.insert(new_values.begin(), plus_one); @@ -563,21 +738,6 @@ static IntegerAttr getFeatureDimensionAttr(Builder &b, StringAttr format, GetFeatureDimension(format, input.getType().cast())); } -//===----------------------------------------------------------------------===// -// Bias op utilities. -//===----------------------------------------------------------------------===// - -// Return a 1D DenseIntElementsAttr for the feature dimension of a BiasAdd. -// Requires input to have ranked tensor. -static DenseIntElementsAttr getBiasFeatureDimension(Builder &b, - StringAttr format, - Value input) { - auto inputType = input.getType().cast(); - size_t featureDim = GetFeatureDimension(format, inputType); - RankedTensorType type = RankedTensorType::get(1, b.getIntegerType(64)); - return DenseIntElementsAttr::get(type, featureDim); -} - //===----------------------------------------------------------------------===// // MatMul op utilities. //===----------------------------------------------------------------------===// @@ -740,8 +900,7 @@ static void BuildArgMinMaxReductionBody(Type input_element_type, StringAttr compare_direction = StringAttr::get(direction, builder->getContext()); Value compare = builder->create( - loc, block->getArgument(0), block->getArgument(2), - /*broadcast_dimensions=*/nullptr, compare_direction); + loc, block->getArgument(0), block->getArgument(2), compare_direction); Value selected_input = builder->create( loc, input_type, compare, block->getArgument(0), block->getArgument(2)); @@ -857,8 +1016,7 @@ static void BuildSortComparisonBody(llvm::ArrayRef element_types, StringAttr compare_direction = StringAttr::get(direction, builder->getContext()); Value compare = builder->create( - loc, block->getArgument(0), block->getArgument(1), - /*broadcast_dimensions=*/nullptr, compare_direction); + loc, block->getArgument(0), block->getArgument(1), compare_direction); builder->create(loc, compare); } @@ -897,6 +1055,27 @@ NamedAttribute GetConvDimensionNumbersAttr( feature_dim, spatial_dims, builder->getContext())); } +// Converts a TF::BiasAddOp to HLO. +// This differs from a normal TF::AddOp with respect to how the data_format +// is handled, which can optionally require a general broadcast of the +// 'bias' term in a way that is not compatible with the standard left-padded +// broadcast semantics (i.e. NCHW will broadcast into dimension 1). +// The correct 'bias' broadcast will be synthesized manually. +class ConvertBiasAddOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TF::BiasAddOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto feature_dim = GetFeatureDimension( + op.data_formatAttr(), op.value().getType().cast()); + auto bias_broadcast = Broadcast1DToFeatureDim(loc, op.value(), op.bias(), + feature_dim, rewriter); + rewriter.replaceOpWithNewOp(op, op.value(), bias_broadcast); + return success(); + } +}; + // Converts the TensorFlow conv op in template to the generic HLO conv op by // converting TensorFlow op attributes to HLO op attributes. 
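The ConvertBiasAddOp pattern above hinges on choosing the feature dimension from the data format: with NHWC the 1-D bias is broadcast along the last dimension, with NCHW along dimension 1, which left-padded broadcasting cannot express directly. A shape-level sketch of that choice (standalone; the helper name is an assumption):

#include <cstdint>
#include <string>

// For a value of rank `rank`, returns the dimension the 1-D bias is broadcast
// along. E.g. value tensor<2x3x4x5> in NCHW with bias tensor<3> -> dimension 1.
static int64_t BiasFeatureDimension(const std::string& data_format,
                                    int64_t rank) {
  return data_format == "NHWC" ? rank - 1 : 1;
}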
// @@ -1158,7 +1337,6 @@ class ConvertDiagPartOp : public OpRewritePattern { rewriter.getI64IntegerAttr(1)); Value compare = rewriter.create( op.getLoc(), iota0, iota1, - /*broadcast_dimensions=*/nullptr, StringAttr::get("EQ", rewriter.getContext())); Value zero = GetScalarConstOfType(input_type.getElementType(), op.getLoc(), 0, &rewriter); @@ -1271,33 +1449,35 @@ class ConvertFusedBatchNormGradBase non_feature_dims.push_back(i); } auto reduce_dims = GetI64ElementsAttr(non_feature_dims, &rewriter); - auto broadcast_dims = GetI64ElementsAttr({feature_dim}, &rewriter); - auto no_broadcast_dims = GetI64ElementsAttr({}, &rewriter); + auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter); // scratch1 = rsqrt(var + epsilon) RankedTensorType scalar_float = RankedTensorType::get({}, kernel_type); auto epsilon = rewriter.create( loc, DenseFPElementsAttr::get(scalar_float, {op.epsilon()})); - auto add_op = rewriter.create(loc, var, epsilon.getResult(), - no_broadcast_dims); + auto add_op = rewriter.create( + loc, var, epsilon.getResult(), scalar_broadcast_dims); + Value scratch1 = rewriter.create(loc, add_op); // scratch2 = sum(y_backprop * (x - mean)) - auto sub_op = rewriter.create(loc, act, mean, broadcast_dims); - auto weighted_grad = - rewriter.create(loc, grad, sub_op, no_broadcast_dims); + auto sub_op = rewriter.create( + loc, act, + Broadcast1DToFeatureDim(loc, act, mean, feature_dim, rewriter)); + auto weighted_grad = rewriter.create(loc, grad, sub_op); Value scratch2 = ApplyReduction(loc, weighted_grad, reduce_dims, &rewriter); // x_backprop = y_backprop * (scale * scratch1) auto scaled_grad = - rewriter.create(loc, op.scale(), scratch1, no_broadcast_dims); - x_backprop = - rewriter.create(loc, grad, scaled_grad, broadcast_dims); + rewriter.create(loc, op.scale(), scratch1); + x_backprop = rewriter.create( + loc, grad, + Broadcast1DToFeatureDim(loc, act, scaled_grad, feature_dim, + rewriter)); // scale_backprop = scratch2 * scratch1 - scale_backprop = - rewriter.create(loc, scratch1, scratch2, no_broadcast_dims); + scale_backprop = rewriter.create(loc, scratch1, scratch2); // offset_backprop = sum(y_backprop) offset_backprop = ApplyReduction(loc, grad, reduce_dims, &rewriter); @@ -1393,7 +1573,7 @@ class ConvertFusedBatchNormV3Op auto factor_const_op = rewriter.create( op.getLoc(), rewriter.getFloatAttr(scale_element_type, factor)); - Value corrected_variance = rewriter.create( + Value corrected_variance = rewriter.create( op.getLoc(), batch_variance.getType(), batch_variance, factor_const_op, /*broadcast_dimensions=*/DenseIntElementsAttr()); @@ -1413,24 +1593,26 @@ class ConvertFusedBatchNormV3Op rewriter.getFloatAttr(mean_element_type, exponential_avg_factor)); // new_running_mean = alpha * old_mean + beta * batch_mean. - auto alpha_mul_old_mean = rewriter.create( + auto alpha_mul_old_mean = rewriter.create( op.getLoc(), op.mean().getType(), alpha, op.mean(), /*broadcast_dimensions=*/DenseIntElementsAttr()); - auto beta_mul_batch_mean = rewriter.create( + auto beta_mul_batch_mean = rewriter.create( op.getLoc(), batch_mean.getType(), beta, batch_mean, /*broadcast_dimensions=*/DenseIntElementsAttr()); - batch_mean = rewriter.create( + batch_mean = rewriter.create( op.getLoc(), alpha_mul_old_mean, beta_mul_batch_mean, /*broadcast_dimensions=*/DenseIntElementsAttr()); // new_running_variance = alpha * old_variance + beta * batch_variance. 
- auto alpha_mul_old_variance = rewriter.create( + auto alpha_mul_old_variance = rewriter.create( op.getLoc(), op.variance().getType(), alpha, op.variance(), /*broadcast_dimensions=*/DenseIntElementsAttr()); - auto beta_mul_batch_variance = rewriter.create( - op.getLoc(), corrected_variance.getType(), beta, corrected_variance, - /*broadcast_dimensions=*/DenseIntElementsAttr()); - corrected_variance = rewriter.create( + auto beta_mul_batch_variance = + rewriter.create( + op.getLoc(), corrected_variance.getType(), beta, + corrected_variance, + /*broadcast_dimensions=*/DenseIntElementsAttr()); + corrected_variance = rewriter.create( op.getLoc(), alpha_mul_old_variance, beta_mul_batch_variance, /*broadcast_dimensions=*/DenseIntElementsAttr()); } @@ -1583,10 +1765,9 @@ class ConvertAvgPoolOp : public OpRewritePattern { // Divide by the number of elements in the window. Value divisor = GetScalarConstOfType(sum_element_type, op.getLoc(), count, &rewriter); - auto batch_dims = - GetI64ElementsAttrForSeq(0, input_type.getRank(), &rewriter); - Value result = rewriter.create(op.getLoc(), result_type, reduce, - divisor, batch_dims); + auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter); + Value result = rewriter.create( + op.getLoc(), result_type, reduce, divisor, scalar_broadcast_dims); // Convert back if we enlarged the element type's bitwidth. if (input_element_type != sum_element_type) @@ -1743,29 +1924,20 @@ class ConvertSigmoidOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::SigmoidOp op, PatternRewriter &rewriter) const override { - auto operand = op.getOperand(); + Location loc = op.getLoc(); - auto scalar_one = rewriter.create( - op.getLoc(), - rewriter.getFloatAttr(getElementTypeOrSelf(operand.getType()), 0.5)); + // Create constant half with shape and element type same as the operand. + Value operand = op.getOperand(); + auto operand_ty = operand.getType().cast(); + auto scalar_ty = RankedTensorType::get({}, operand_ty.getElementType()); + ElementsAttr attr = mlir::xla::getSplat(&rewriter, scalar_ty, 0.5); + auto scalar_half = rewriter.create(loc, attr); + auto half = BroadcastToShapeOf(loc, scalar_half, operand, rewriter); - auto type = operand.getType().dyn_cast(); - if (!type) - return rewriter.notifyMatchFailure(op, "requires ranked tensor type"); - auto constant_ones = rewriter.create( - op.getLoc(), type, scalar_one, - GetI64ElementsAttr(type.getShape(), &rewriter)); - - auto scaled_input = rewriter.create( - op.getLoc(), operand, constant_ones, DenseIntElementsAttr()); - auto tanh_op = - rewriter.create(op.getLoc(), operand.getType(), scaled_input); - auto mul_op = - rewriter.create(op.getLoc(), tanh_op, constant_ones, - /*DenseIntElementsAttr=*/DenseIntElementsAttr()); - auto add_op = - rewriter.create(op.getLoc(), mul_op, constant_ones, - /*DenseIntElementsAttr=*/DenseIntElementsAttr()); + auto scaled_input = rewriter.create(loc, operand, half); + auto tanh_op = rewriter.create(loc, scaled_input); + auto mul_op = rewriter.create(loc, tanh_op, half); + auto add_op = rewriter.create(loc, mul_op, half); rewriter.replaceOp(op, add_op.getResult()); return success(); @@ -1804,20 +1976,18 @@ class ConvertSoftmaxOp : public OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - Value logits = op.logits(); - // Softmax converter requires ranked type because the XLA reduce ops used // while lowering requires dimensions attribute to reduce along. 
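The ConvertSigmoidOp rewrite above emits a broadcasted 0.5 constant, a multiply, a tanh, another multiply, and an add, relying on the identity sigmoid(x) = 0.5 * tanh(0.5 * x) + 0.5. A quick standalone numeric check of that identity:

#include <cmath>
#include <cstdio>

int main() {
  // sigmoid(x) = 1 / (1 + e^-x) = 0.5 * tanh(0.5 * x) + 0.5
  for (double x : {-4.0, -1.0, 0.0, 0.5, 3.0}) {
    const double direct = 1.0 / (1.0 + std::exp(-x));
    const double via_tanh = 0.5 * std::tanh(0.5 * x) + 0.5;
    std::printf("x=%+.1f  direct=%.12f  via_tanh=%.12f\n", x, direct, via_tanh);
  }
  return 0;
}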
+ // Note that the input and output shape is equivalent, so we use 'logits' + // and its type for shape calculations. + Value logits = op.logits(); RankedTensorType type = logits.getType().dyn_cast(); if (!type) return failure(); - auto loc = op.getLoc(); int rank = type.getRank(); // Note that the TensorFlow Softmax op verifies that the input rank is - // greater than or equal to one so both of the following sequences are - // valid. - auto batch_dims = GetI64ElementsAttrForSeq(0, rank - 1, &rewriter); + // greater than or equal to one so the following sequence is valid. auto reduce_dim = rewriter.create( loc, GetI64ElementsAttr({rank - 1}, &rewriter)); @@ -1830,8 +2000,10 @@ class ConvertSoftmaxOp : public OpRewritePattern { auto max_logits = rewriter.create(loc, logits, reduce_dim, /*keep_dims=*/rewriter.getBoolAttr(false)); - auto shifted_logits = - rewriter.create(loc, type, logits, max_logits, batch_dims); + auto max_logits_broadcast = + CommonPrefixBroadcast(loc, logits, max_logits, rewriter); + auto shifted_logits = rewriter.create(loc, type, logits, + max_logits_broadcast); // Exponentiate the inputs. Value exp = rewriter.create(loc, type, shifted_logits); @@ -1844,9 +2016,12 @@ class ConvertSoftmaxOp : public OpRewritePattern { if (use_log) { Value log = rewriter.create(loc, sum); - rewriter.replaceOpWithNewOp(op, shifted_logits, log, batch_dims); + auto log_broadcast = CommonPrefixBroadcast(loc, logits, log, rewriter); + rewriter.replaceOpWithNewOp(op, shifted_logits, + log_broadcast); } else { - rewriter.replaceOpWithNewOp(op, exp, sum, batch_dims); + auto sum_broadcast = CommonPrefixBroadcast(loc, logits, sum, rewriter); + rewriter.replaceOpWithNewOp(op, exp, sum_broadcast); } return success(); } @@ -1893,7 +2068,7 @@ class ConvertSizeOp : public OpRewritePattern { auto dim = rewriter.create( op.getLoc(), result_type, input, rewriter.getIntegerAttr(rewriter.getIntegerType(32), i)); - size = rewriter.create( + size = rewriter.create( op.getLoc(), size->getResult(0), dim.getResult(), /*DenseIntElementsAttr=*/DenseIntElementsAttr()); } @@ -2579,16 +2754,31 @@ class ConvertRangeOp : public OpRewritePattern { auto iota = rewriter.create(op.getLoc(), result_type, rewriter.getI64IntegerAttr(0)); - auto scaled = rewriter.create( + auto scaled = rewriter.create( op.getLoc(), result_type, iota, op.delta(), xla::getBroadcastDimensionsAttr(&rewriter, iota, op.delta())); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, result_type, scaled, op.start(), xla::getBroadcastDimensionsAttr(&rewriter, scaled, op.start())); return success(); } }; +ElementsAttr ConvertAxisAttr(Value val, ElementsAttr attr, Builder *builder) { + auto int_attr = attr.cast(); + auto type = val.getType().cast(); + + SmallVector axis; + axis.reserve(int_attr.getNumElements()); + + int64_t rank = type.getRank(); + for (auto val : int_attr.getValues()) { + axis.push_back((val.getSExtValue() + rank) % rank); + } + + return builder->getI64TensorAttr(axis); +} + /// Converts the LinSpace tensorflow op to a xla_hlo.iota op with a scaling /// and offset applied to generate the linspace values. The output tensor needs /// to have a static shape. The implementation is defined in C++ because there @@ -2615,7 +2805,7 @@ class ConvertLinSpaceOp : public OpRewritePattern { int64_t num = (*num_attr.begin()).getSExtValue(); // Calculate the scaling that needs to be applied to the iota. 
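The LinSpace lowering below realizes the values as start + iota * step, where step = (stop - start) / (num - 1) for num > 1 and the denominator stays at 1 when num == 1. A plain reference computation (sketch only, names assumed):

#include <cstdio>
#include <vector>

static std::vector<double> LinSpaceReference(double start, double stop,
                                             int num) {
  const double denom = num > 1 ? num - 1 : 1;
  const double step = (stop - start) / denom;
  std::vector<double> out;
  for (int i = 0; i < num; ++i) out.push_back(start + i * step);
  return out;
}

int main() {
  for (double v : LinSpaceReference(0.0, 1.0, 5)) std::printf("%g ", v);
  std::printf("\n");  // prints: 0 0.25 0.5 0.75 1
}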
- auto step_numerator = rewriter.create( + auto step_numerator = rewriter.create( op.getLoc(), op.start().getType(), op.stop(), op.start(), xla::getBroadcastDimensionsAttr(&rewriter, op.stop(), op.start())); Value step_denominator = rewriter.create( @@ -2623,11 +2813,11 @@ class ConvertLinSpaceOp : public OpRewritePattern { if (num > 1) { Value one = GetScalarConstOfType(result_type.getElementType(), op.getLoc(), 1, &rewriter); - step_denominator = rewriter.create( + step_denominator = rewriter.create( op.getLoc(), step_denominator.getType(), step_denominator, one, xla::getBroadcastDimensionsAttr(&rewriter, step_denominator, one)); } - auto step = rewriter.create( + auto step = rewriter.create( op.getLoc(), step_numerator.getType(), step_numerator, step_denominator, xla::getBroadcastDimensionsAttr(&rewriter, step_numerator, step_denominator)); @@ -2635,10 +2825,10 @@ class ConvertLinSpaceOp : public OpRewritePattern { // Scale the iota and add the offset. auto iota = rewriter.create(op.getLoc(), result_type, rewriter.getI64IntegerAttr(0)); - auto scaled = rewriter.create( + auto scaled = rewriter.create( op.getLoc(), result_type, iota, step, xla::getBroadcastDimensionsAttr(&rewriter, iota, step)); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, result_type, scaled, op.start(), xla::getBroadcastDimensionsAttr(&rewriter, scaled, op.start())); return success(); @@ -2714,8 +2904,8 @@ class GenericConvertReductionOp : public OpRewritePattern { auto divisor = GetScalarConstOfType(reduce_element_type, loc, divisor_count, &rewriter); auto broadcast_dims = GetI64ElementsAttr({}, &rewriter); - result = rewriter.create(loc, result, divisor.getResult(), - broadcast_dims); + result = rewriter.create( + loc, result, divisor.getResult(), broadcast_dims); } result = rewriter.create(loc, result, element_type); @@ -3100,7 +3290,6 @@ class ConvertMaxPoolGradOp : public OpRewritePattern { auto reducer = rewriter.create( loc, block->getArgument(0), block->getArgument(1), - /*broadcast_dimensions=*/nullptr, StringAttr::get("GE", rewriter.getContext())); rewriter.create(loc, reducer.getResult()); } @@ -3526,13 +3715,20 @@ class ConvertOneHotOp : public OpRewritePattern { output_dims.insert(output_dims.begin() + axis, depth); Location loc = op.getLoc(); + + // The iota result is the effective output shape of the computation, + // and indices must be broadcast into it. At this point, this computation + // would need to be reworked quite a bit to support dynamic shapes, so + // just using static broadcasting. auto index_type = RankedTensorType::get(output_dims, element_type); - Value compare = rewriter.create( - loc, op.indices(), - rewriter.create( - loc, index_type, - IntegerAttr::get(rewriter.getIntegerType(64), axis)), - GetI64ElementsAttr(broadcast_dims, &rewriter), + auto iota = rewriter.create( + loc, index_type, IntegerAttr::get(rewriter.getIntegerType(64), axis)); + auto broadcast_indices = rewriter.create( + loc, index_type, op.indices(), + GetI64ElementsAttr(broadcast_dims, &rewriter)); + + Value compare = rewriter.create( + loc, broadcast_indices, iota, StringAttr::get("EQ", rewriter.getContext())); Value on_value = rewriter.create( loc, op.getType(), op.on_value(), @@ -4181,6 +4377,68 @@ class ConvertXlaShardingOp : public OpRewritePattern { } }; +// Converts a TF InplaceUpdate op to DynamicUpdateSlice HLO. 
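The converter that follows unpacks the index vector i and splits the update rows of v, then chains one DynamicUpdateSlice per row into x. Its intended effect is x[i[k], ...] = v[k, ...] for every k; because the slices are applied sequentially, a repeated index keeps the last update. A reference version over 2-D data (standalone sketch):

#include <cstdint>
#include <vector>

// x: [rows][cols], indices: [k], updates: [k][cols].
static void InplaceUpdateRows(std::vector<std::vector<float>>& x,
                              const std::vector<int64_t>& indices,
                              const std::vector<std::vector<float>>& updates) {
  for (size_t k = 0; k < indices.size(); ++k) x[indices[k]] = updates[k];
}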
+class ConvertInplaceUpdateOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::InplaceUpdateOp op, + PatternRewriter &rewriter) const override { + auto input = op.x(); + auto indices = op.i(); + auto updates = op.v(); + + // Slice each row of `i` and `v` to perform a separate dynamic-update-slice + // on the contents of `x`. + auto input_type = input.getType().cast(); + auto updates_type = updates.getType().cast(); + auto indices_type = indices.getType().cast(); + if (!indices_type.hasStaticShape()) return failure(); + + if (indices_type.getRank() != 1) return failure(); + + SmallVector unpacked_indices_type( + indices_type.getDimSize(0), + RankedTensorType::get({}, indices_type.getElementType())); + auto zero_attr = IntegerAttr::get(rewriter.getIntegerType(64), 0); + auto unpacked_indices = rewriter.create( + op.getLoc(), unpacked_indices_type, indices, zero_attr); + + SmallVector split_updates_shape; + split_updates_shape.append(updates_type.getShape().begin(), + updates_type.getShape().end()); + split_updates_shape.front() = 1; + SmallVector split_updates_type; + split_updates_type.resize( + updates_type.getShape().front(), + RankedTensorType::get(split_updates_shape, + updates_type.getElementType())); + + auto cst = + rewriter.create(op.getLoc(), zero_attr).getResult(); + auto split_updates = rewriter.create( + op.getLoc(), split_updates_type, cst, updates); + + SmallVector input_indices; + input_indices.resize(input_type.getRank(), cst); + + SmallVector starts(updates_type.getRank(), 0); + SmallVector strides(updates_type.getRank(), 1); + SmallVector limits(updates_type.getShape().begin(), + updates_type.getShape().end()); + + for (auto pair : + llvm::zip(unpacked_indices.output(), split_updates.output())) { + input_indices.front() = std::get<0>(pair); + input = rewriter.create( + op.getLoc(), op.getType(), input, std::get<1>(pair), input_indices); + } + + rewriter.replaceOp(op, input); + return success(); + } +}; + // Converts a TF XlaDynamicUpdateSlice op to DynamicUpdateSlice HLO. class ConvertXlaDynamicUpdateSliceOp : public OpRewritePattern { @@ -4316,7 +4574,6 @@ class ConvertQrOp : public OpRewritePattern { rewriter.getI64IntegerAttr(1)); Value compare = rewriter.create( op.getLoc(), iota0, iota1, - /*broadcast_dimensions=*/nullptr, StringAttr::get("EQ", rewriter.getContext())); Value identity_matrix = rewriter.create(op.getLoc(), compare, type.getElementType()); @@ -4350,8 +4607,7 @@ class ConvertQrOp : public OpRewritePattern { batch_dims.size(), precision_config, &rewriter); a_update = BatchDot(op.getLoc(), y, false, a_update, false, batch_dims.size(), precision_config, &rewriter); - a_panel = rewriter.create(op.getLoc(), a_panel, a_update, - /*broadcast_dimensions=*/nullptr); + a_panel = rewriter.create(op.getLoc(), a_panel, a_update); a = UpdateSliceInMinorDims(op.getLoc(), a, a_panel, {i, i + k}, &rewriter); @@ -4362,8 +4618,7 @@ class ConvertQrOp : public OpRewritePattern { batch_dims.size(), precision_config, &rewriter); q_update = BatchDot(op.getLoc(), q_update, false, y, true, batch_dims.size(), precision_config, &rewriter); - q_panel = rewriter.create(op.getLoc(), q_panel, q_update, - /*broadcast_dimensions=*/nullptr); + q_panel = rewriter.create(op.getLoc(), q_panel, q_update); q = UpdateSliceInMinorDims(op.getLoc(), q, q_panel, {i}, &rewriter); } // full_matrices is false when only a partial result in needed. 
Slice to the @@ -4425,34 +4680,31 @@ class ConvertQrOp : public OpRewritePattern { Value iota = builder->create( loc, RankedTensorType::get({m}, builder->getIntegerType(32)), builder->getI64IntegerAttr(0)); - Value gtk = builder->create( + Value gtk = builder->create( loc, iota, k, GetI64ElementsAttr({}, builder), StringAttr::get("GT", builder->getContext())); gtk = builder->create(loc, gtk, x_type.getElementType()); - Value x_after_k = builder->create( + Value x_after_k = builder->create( loc, x, gtk, GetI64ElementsAttr({minor_dim}, builder)); - Value x_after_k_sq = builder->create( - loc, x_after_k, x_after_k, /*broadcast_dimensions=*/nullptr); + Value x_after_k_sq = builder->create(loc, x_after_k, x_after_k); // sigma = np.dot(x[k+1:], x[k+1:]) auto sigma = builder->create( loc, x_after_k_sq, zero, GetI64ElementsAttr({minor_dim}, builder)); BuildReduceBody(x_type.getElementType(), &sigma.body(), builder); // mu = np.sqrt(x[k]*x[k] + sigma) - Value alpha_sq = builder->create(loc, alpha, alpha, - /*broadcast_dimensions=*/nullptr); + Value alpha_sq = builder->create(loc, alpha, alpha); Value mu = builder->create( - loc, builder->create(loc, alpha_sq, sigma.getResult(0), - /*broadcast_dimensions=*/nullptr)); + loc, builder->create(loc, alpha_sq, sigma.getResult(0))); - Value sigma_is_zero = builder->create( + Value sigma_is_zero = builder->create( loc, sigma.getResult(0), zero, GetI64ElementsAttr({}, builder), StringAttr::get("EQ", builder->getContext())); - Value alpha_is_negative = builder->create( + Value alpha_is_negative = builder->create( loc, alpha, zero, GetI64ElementsAttr({}, builder), StringAttr::get("LT", builder->getContext())); auto batch_size_one = builder->create( loc, alpha.getType(), one, GetI64ElementsAttr(batch_dims, builder)); - Value signed_mu = builder->create( + Value signed_mu = builder->create( loc, builder->create(loc, mu.getType(), alpha_is_negative, batch_size_one, @@ -4461,21 +4713,16 @@ class ConvertQrOp : public OpRewritePattern { *beta = builder->create(loc, alpha.getType(), sigma_is_zero, alpha, signed_mu); *tau = builder->create( - loc, - builder->create(loc, *beta, alpha, - /*broadcast_dimensions=*/nullptr), - *beta, - /*broadcast_dimensions=*/nullptr); + loc, builder->create(loc, *beta, alpha), *beta); Value zero_tau = builder->create( loc, alpha.getType(), zero, GetI64ElementsAttr(batch_dims, builder)); *tau = builder->create(loc, alpha.getType(), sigma_is_zero, zero_tau, *tau); - Value divisor = builder->create(loc, alpha, *beta, - /*broadcast_dimensions=*/nullptr); + Value divisor = builder->create(loc, alpha, *beta); divisor = builder->create(loc, divisor.getType(), sigma_is_zero, batch_size_one, divisor); - Value eqk = builder->create( + Value eqk = builder->create( loc, iota, k, GetI64ElementsAttr({}, builder), StringAttr::get("EQ", builder->getContext())); eqk = builder->create(loc, eqk, x_type.getElementType()); @@ -4488,10 +4735,12 @@ class ConvertQrOp : public OpRewritePattern { // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor. - *v = builder->create( + // Note that the add performs a degenerate broadcast. 
+ *v = builder->create( loc, e_k, - builder->create(loc, x_after_k, divisor, - GetI64ElementsAttr(batch_dim_ids, builder)), + StaticBinaryBroadcast(loc, x_after_k, divisor, + GetI64ElementsAttr(batch_dim_ids, builder), + *builder), /*broadcast_dimensions=*/nullptr); } @@ -4565,10 +4814,10 @@ class ConvertQrOp : public OpRewritePattern { precision, builder); vva = BatchDot(loc, v_broadcast, true, vva, false, num_batch_dims, precision, builder); - auto tau_x_vva = builder->create( - loc, tau, vva, GetI64ElementsAttr(batch_dim_indices, builder)); - a = builder->create(loc, a, tau_x_vva, - /*broadcast_dimensions=*/nullptr); + auto tau_x_vva = StaticBinaryBroadcast( + loc, tau, vva, GetI64ElementsAttr(batch_dim_indices, builder), + *builder); + a = builder->create(loc, a, tau_x_vva); // It is more precise to populate column 'k' explicitly, rather than // computing it implicitly by applying the Householder transformation. @@ -4577,12 +4826,12 @@ class ConvertQrOp : public OpRewritePattern { auto iota = builder->create( loc, RankedTensorType::get({m, 1}, builder->getIntegerType(32)), builder->getI64IntegerAttr(0)); - Value predecessor_mask = builder->create( + Value predecessor_mask = builder->create( loc, iota, j, GetI64ElementsAttr({}, builder), StringAttr::get("LT", builder->getContext())); predecessor_mask = builder->create(loc, predecessor_mask, a_type.getElementType()); - Value mask = builder->create( + Value mask = builder->create( loc, iota, j, GetI64ElementsAttr({}, builder), StringAttr::get("EQ", builder->getContext())); mask = builder->create(loc, mask, a_type.getElementType()); @@ -4594,14 +4843,14 @@ class ConvertQrOp : public OpRewritePattern { mask, GetI64ElementsAttr(llvm::SmallVector(num_batch_dims, 1), builder)); - Value predecessor_masked_x = builder->create( + Value predecessor_masked_x = StaticBinaryBroadcast( loc, x, predecessor_mask, - GetI64ElementsAttr({num_dims - 2, num_dims - 1}, builder)); - Value masked_beta = builder->create( - loc, beta, mask, GetI64ElementsAttr(batch_dim_indices, builder)); + GetI64ElementsAttr({num_dims - 2, num_dims - 1}, builder), *builder); + Value masked_beta = StaticBinaryBroadcast( + loc, beta, mask, GetI64ElementsAttr(batch_dim_indices, builder), + *builder); Value new_x = - builder->create(loc, predecessor_masked_x, masked_beta, - /*broadcast_dimensions=*/nullptr); + builder->create(loc, predecessor_masked_x, masked_beta); // Update a[:,j] llvm::SmallVector dim_ids(num_dims); std::iota(dim_ids.begin(), dim_ids.end(), 0); @@ -4612,7 +4861,7 @@ class ConvertQrOp : public OpRewritePattern { loc, RankedTensorType::get(a_type.getShape(), builder->getIntegerType(32)), builder->getI64IntegerAttr(minor_dim + 1)); - Value xa_mask = builder->create( + Value xa_mask = builder->create( loc, iota_mn, j, GetI64ElementsAttr({}, builder), StringAttr::get("EQ", builder->getContext())); a = builder->create(loc, a_type, xa_mask, new_x, a); @@ -4628,11 +4877,11 @@ class ConvertQrOp : public OpRewritePattern { builder)); auto vs_update = builder->create( loc, vs.getType(), xa_mask, - builder->create( - loc, vs_zeros, v, GetI64ElementsAttr(vs_broadcast_dims, builder)), + StaticBinaryBroadcast( + loc, vs_zeros, v, GetI64ElementsAttr(vs_broadcast_dims, builder), + *builder), vs_zeros); - vs = builder->create(loc, vs, vs_update, - /*broadcast_dimensions=*/nullptr); + vs = builder->create(loc, vs, vs_update); // taus[j] = tau llvm::SmallVector tau_broadcast_dims(batch_dims.size()); @@ -4649,17 +4898,16 @@ class ConvertQrOp : public OpRewritePattern { loc, 
taus.getType(), taus_zeros, GetI64ElementsAttr(taus.getType().cast().getShape(), builder)); - Value taus_mask = builder->create( + Value taus_mask = builder->create( loc, iota_n, j, GetI64ElementsAttr({}, builder), StringAttr::get("EQ", builder->getContext())); auto taus_update = builder->create( loc, taus.getType(), taus_mask, - builder->create( + StaticBinaryBroadcast( loc, taus_zeros, tau, - GetI64ElementsAttr(tau_broadcast_dims, builder)), + GetI64ElementsAttr(tau_broadcast_dims, builder), *builder), taus_zeros); - taus = builder->create(loc, taus, taus_update, - /*broadcast_dimensions=*/nullptr); + taus = builder->create(loc, taus, taus_update); new_values->assign({a, vs, taus}); }; @@ -4716,8 +4964,7 @@ class ConvertQrOp : public OpRewritePattern { j = builder->create( loc, j, GetScalarConstOfType(getElementTypeOrSelf(j.getType()), loc, 1, - builder), - /*broadcast_dimensions=*/nullptr); + builder)); // vs has shape [..., m, 1] auto v = DynamicSliceInMinorDims(loc, vs, {j}, {1}, builder); // beta has shape [..., 1] @@ -4736,7 +4983,7 @@ class ConvertQrOp : public OpRewritePattern { loc, vs.getType(), zero, GetI64ElementsAttr(vs.getType().cast().getShape(), builder)); - auto compare = builder->create( + auto compare = builder->create( loc, iota_mn, j, GetI64ElementsAttr({}, builder), StringAttr::get("GE", builder->getContext())); auto y = builder->create(loc, vs.getType(), compare, zero, vs); @@ -4751,13 +4998,12 @@ class ConvertQrOp : public OpRewritePattern { // z = -beta * (v + wyv) auto neg_beta = builder->create(loc, beta); - auto v_wyv = builder->create(loc, v, wyv, - /*broadcast_dimensions=*/nullptr); + auto v_wyv = builder->create(loc, v, wyv); auto beta_broadcast_dims = llvm::to_vector<4>(batch_dim_indices); beta_broadcast_dims.push_back(n_index); - auto z = builder->create( + auto z = StaticBinaryBroadcast( loc, neg_beta, v_wyv, - GetI64ElementsAttr(beta_broadcast_dims, builder)); + GetI64ElementsAttr(beta_broadcast_dims, builder), *rewriter); w = DynamicUpdateSliceInMinorDims(loc, w, z, {j}, builder); new_values->assign({w, vs, taus}); @@ -4775,8 +5021,9 @@ class ConvertQrOp : public OpRewritePattern { auto neg_beta = rewriter->create(loc, beta); auto beta_broadcast_dims = llvm::to_vector<4>(batch_dim_indices); beta_broadcast_dims.push_back(n_index); - auto bv = rewriter->create( - loc, neg_beta, v, GetI64ElementsAttr(beta_broadcast_dims, rewriter)); + auto bv = StaticBinaryBroadcast( + loc, neg_beta, v, GetI64ElementsAttr(beta_broadcast_dims, rewriter), + *rewriter); w = UpdateSliceInMinorDims(loc, w, bv, {0}, rewriter); SmallVector while_output; @@ -4785,9 +5032,55 @@ class ConvertQrOp : public OpRewritePattern { } }; +// Emits debug information which includes the number of ops of each type which +// failed to legalize. +void EmitLegalizationErrors(Operation *op, + const DenseSet &nonlegalized_ops) { + // Track the legalization failures by mapping op name to information about + // that failure: the number of unlegalized occurances of the op, and one + // example operation that failed. + std::map> op_name_to_error_info; + DenseSet error_ops; + for (Operation *nonlegalized_op : nonlegalized_ops) { + // Increment count of this legalization failure. + StringRef op_name = nonlegalized_op->getName().getStringRef(); + // If this emplace is successful, it's the first time we've encountered + // this op type. Initialize count to 0 so that after increment, it is 1. 
+ auto insertion_result = op_name_to_error_info.emplace( + op_name, std::make_pair(0, nonlegalized_op)); + ++insertion_result.first->second.first; + } + std::vector error_messages; + error_messages.reserve(op_name_to_error_info.size()); + for (const auto &op_info : op_name_to_error_info) { + error_messages.push_back( + llvm::formatv("{0} (count: {1})", op_info.first, op_info.second.first)); + } + Location loc = op->getLoc(); + emitError(loc) << "The following operations cannot be legalized: " + << llvm::join(error_messages, "; ") + << ". These legalization failure(s) may be due to missing TF " + "to HLO lowerings and/or unsupported attributes, etc."; + // Emit more information about the missing ops. This error message + // contains useful details beyond the op name (input and output shapes, + // attributes, etc.). + if (!VLOG_IS_ON(1) && nonlegalized_ops.size() != 1) { + emitError(loc) + << "Emitting more detail about one op that failed to legalize..."; + } else if (VLOG_IS_ON(1)) { + emitError(loc) << "Emitting more detail about one of each type of op " + "that failed to legalize..."; + } + for (const auto &op_info : op_name_to_error_info) { + op_info.second.second->emitOpError() << "is not legalizable"; + if (!VLOG_IS_ON(1)) break; + } +} + // Performs the lowering to XLA dialect. void LegalizeTF::runOnFunction() { - if (failed(legalizeTF(getFunction(), allow_partial_conversion_))) + if (failed( + legalizeTF(getFunction(), allow_partial_conversion_, legalize_chlo_))) signalPassFailure(); } @@ -4798,7 +5091,8 @@ static PassRegistration pass( #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc" -LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) { +LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion, + bool legalize_chlo) { MLIRContext *context = op->getContext(); // Add lowering patterns to the list. 
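[Editor's note: illustrative sketch, not part of the patch.] The EmitLegalizationErrors helper above counts legalization failures per op name with an emplace-then-increment idiom: the emplace succeeds only the first time a name is seen, inserting a count of 0 so the unconditional increment makes it 1. A minimal standalone C++ sketch of that idiom, with purely hypothetical op names and without any MLIR/LLVM types, looks like this:

#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Stand-ins for the non-legalized operations collected by the conversion.
  std::vector<std::string> nonlegalized = {"tf.Foo", "tf.Bar", "tf.Foo"};

  // op name -> (failure count, index of one example occurrence).
  std::map<std::string, std::pair<int, size_t>> op_name_to_error_info;
  for (size_t i = 0; i < nonlegalized.size(); ++i) {
    // emplace only inserts for a previously unseen name; count starts at 0
    // so that the increment below turns it into 1 on first sight.
    auto insertion_result =
        op_name_to_error_info.emplace(nonlegalized[i], std::make_pair(0, i));
    ++insertion_result.first->second.first;
  }

  for (const auto &entry : op_name_to_error_info)
    std::cout << entry.first << " (count: " << entry.second.first << ")\n";
  // Prints: tf.Bar (count: 1), then tf.Foo (count: 2).
}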
@@ -4811,13 +5105,14 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) { TF::PopulateLoweringTFPatterns(context, &patterns); patterns.insert< ConvertAllOp, ConvertAnyOp, ConvertArgMaxOp, ConvertBatchMatMulV2Op, - ConvertBroadcastToOp, ConvertBF16FloorDivOp, ConvertConv2DOp, - ConvertConv3DOp, ConvertDepthConv2DOp, ConvertConv2DBackpropFilterOp, - ConvertConv3DBackpropFilterOp, ConvertConv2DBackpropInputOp, - ConvertConv3DBackpropInputOp, ConvertCumsumOp, ConvertDiagPartOp, - ConvertEinsumOp, ConvertFusedBatchNormGradOp, - ConvertFusedBatchNormGradV2Op, ConvertFusedBatchNormGradV3Op, - ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp, ConvertLinSpaceOp, + ConvertBiasAddOp, ConvertBroadcastToOp, ConvertBF16FloorDivOp, + ConvertConv2DOp, ConvertConv3DOp, ConvertDepthConv2DOp, + ConvertConv2DBackpropFilterOp, ConvertConv3DBackpropFilterOp, + ConvertConv2DBackpropInputOp, ConvertConv3DBackpropInputOp, + ConvertCumsumOp, ConvertDiagPartOp, ConvertEinsumOp, + ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op, + ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV3Op, + ConvertInfeedDequeueTupleOp, ConvertInplaceUpdateOp, ConvertLinSpaceOp, ConvertMaxOp, ConvertMinOp, ConvertAvgPoolOp, ConvertMaxPool2DOp, ConvertMaxPool3DOp, ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp, ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, @@ -4831,7 +5126,18 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) { ConvertRandomShuffleOp, ConvertVariableShapeOp, ConvertXlaShardingOp, ConvertXlaDynamicUpdateSliceOp>(op->getContext()); + // Populate with CHLO->HLO lowerings to account for TF ops legalized to + // CHLO first. + if (legalize_chlo) { + xla_chlo::PopulateLegalizeChloToHloPatterns(context, &patterns); + } + ConversionTarget target(*context); + if (legalize_chlo) { + target.addIllegalDialect(); + } else { + target.addLegalDialect(); + } target.addLegalDialect(); target.addLegalDialect(); target.addLegalDialect(); @@ -4841,15 +5147,24 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) { if (!allow_partial_conversion) { // Fully qualify ReturnOp here as xla_hlo dialect also defines a ReturnOp. target.addLegalOp(); - return applyFullConversion(op, target, patterns); + DenseSet nonlegalized_ops; + LogicalResult result = applyPartialConversion( + op, target, patterns, /*converter=*/nullptr, &nonlegalized_ops); + // In order to enforce that the conversion result is fully converted, + // fail if there are any nonlegalized ops in the set. 
+ if (failed(result) || !nonlegalized_ops.empty()) { + EmitLegalizationErrors(op, nonlegalized_ops); + return failure(); + } + return result; } return applyPartialConversion(op, target, patterns); } std::unique_ptr> createLegalizeTFPass( - bool allow_partial_conversion) { - return std::make_unique(allow_partial_conversion); + bool allow_partial_conversion, bool legalize_chlo) { + return std::make_unique(allow_partial_conversion, legalize_chlo); } } // end namespace xla_hlo diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc index 86927fe0e07..d5e5b6f5a71 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc @@ -66,7 +66,7 @@ createLegalizeTFControlFlowPass() { namespace { void Detuple(Value tuple, Operation::result_range replace, OpBuilder* builder) { - // De-tuple the results of the xla hlo conditional result. + // De-tuple the results of the xla hlo if result. for (auto result_it : llvm::enumerate(replace)) { auto get_tuple_value = builder->create( result_it.value().getLoc(), tuple, result_it.index()); @@ -74,14 +74,13 @@ void Detuple(Value tuple, Operation::result_range replace, OpBuilder* builder) { } } -// Imports the source region into the destination region. The XLA conditional +// Imports the source region into the destination region. The XLA if // operation only supports one argument per branch. Therefore any branch that // requires additional arguments requires their values be tupled together. Then, // to support multiple returns (as XLA only supports a single return value) the -// results of the conditional are tupled together. +// results of the if operation are tupled together. void ImportXlaRegion(mlir::FuncOp func, Region* dest_region, Location loc, bool tuple_return = true) { - BlockAndValueMapping mapper; OpBuilder builder(dest_region); auto entry_block = builder.createBlock(dest_region); @@ -111,27 +110,52 @@ void LowerIf(TF::IfOp op, ModuleOp module) { // XLA prefers tuple arguments for control flow due to XLA not supporting // multiple return values. SmallVector inputs(op.input()); - builder.setInsertionPoint(op); auto tuple_input = builder.create(loc, inputs); - // Create the new conditional op with tuple inputs. - SmallVector operands(op.getOperands()); + // Create the new if op with tuple inputs. auto result_type = builder.getTupleType(op.getResultTypes()); - auto conditional = builder.create( - loc, result_type, op.cond(), tuple_input, tuple_input); + auto if_op = builder.create(loc, result_type, op.cond(), + tuple_input, tuple_input); // Import the regions for both the true and false cases. These regions // must be updated to tuple the return results together and use the xla hlo // return op. - BlockAndValueMapping mapper; auto then_branch = module.lookupSymbol(op.then_branch()); auto else_branch = module.lookupSymbol(op.else_branch()); - ImportXlaRegion(then_branch, &conditional.true_branch(), loc); - ImportXlaRegion(else_branch, &conditional.false_branch(), loc); + ImportXlaRegion(then_branch, &if_op.true_branch(), loc); + ImportXlaRegion(else_branch, &if_op.false_branch(), loc); - // De-tuple the results of the xla hlo conditional result. - builder.setInsertionPointAfter(op); - Detuple(conditional.getResult(), op.getResults(), &builder); + // De-tuple the results of the xla hlo if result. 
+ Detuple(if_op.getResult(), op.getResults(), &builder); + op.erase(); +} + +void LowerCase(TF::CaseOp op, ModuleOp module) { + Location loc = op.getLoc(); + OpBuilder builder(op); + + // XLA requires one argument per branch so we create a tuple of inputs to pass + // to each branch. + SmallVector inputs(op.input()); + auto tuple_input = builder.create(loc, inputs); + + // Create replica of input tuple for each branch + SmallVector n_tuple_inputs(op.branches().size(), tuple_input); + + // Create the new case op with tuple inputs. + auto case_op = builder.create( + loc, op.getResultTypes(), op.branch_index(), n_tuple_inputs, + op.branches().size()); + + // Import the regions for all branches. + for (unsigned i = 0; i < op.branches().size(); ++i) { + mlir::FuncOp branch_func = module.lookupSymbol( + op.branches()[i].cast()); + ImportXlaRegion(branch_func, &case_op.branches()[i], loc, + /*tuple_return=*/false); + } + + op.replaceAllUsesWith(case_op.getResults()); op.erase(); } @@ -146,7 +170,6 @@ void LowerWhile(TF::WhileOp op, ModuleOp module) { Value tuple_input = builder.create(loc, inputs); // Create the new while op with tuple inputs. - SmallVector operands(op.getOperands()); auto while_op = builder.create( loc, builder.getTupleType(op.getResultTypes()), tuple_input); @@ -159,7 +182,6 @@ void LowerWhile(TF::WhileOp op, ModuleOp module) { ImportXlaRegion(cond_branch, &while_op.cond(), loc, /*tuple_return=*/false); // De-tuple the results of the xla hlo while. - builder.setInsertionPointAfter(op); Detuple(while_op.getResult(), op.getResults(), &builder); op.erase(); } @@ -168,8 +190,20 @@ void LowerWhile(TF::WhileOp op, ModuleOp module) { void LegalizeTFControlFlow::runOnOperation() { auto module = getOperation(); - module.walk([&](TF::WhileOp op) -> void { LowerWhile(op, module); }); - module.walk([&](TF::IfOp op) -> void { LowerIf(op, module); }); + module.walk([&](Operation* op) { + if (auto while_op = dyn_cast(op)) { + LowerWhile(while_op, module); + return; + } + if (auto if_op = dyn_cast(op)) { + LowerIf(if_op, module); + return; + } + if (auto case_op = dyn_cast(op)) { + LowerCase(case_op, module); + return; + } + }); } } // namespace xla_hlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index b2a7c1e7f62..ef5a8356a32 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -18,6 +18,7 @@ limitations under the License. include "mlir/IR/OpBase.td" include "mlir/Dialect/StandardOps/IR/Ops.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" +include "tensorflow/compiler/mlir/xla/ir/chlo_ops.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td" def SignedIntTensor : TensorOf<[I1, I8, I16, I32, I64]>; @@ -72,18 +73,6 @@ def : Pattern< // HLO and XLA doesn't support Assertions. def LowerAssert : Pattern<(TF_AssertOp $condition, $data, $summarize), []>; -//===----------------------------------------------------------------------===// -// Bias op patterns. -//===----------------------------------------------------------------------===// -def BiasAddFeatureDimension : NativeCodeCall< - "getBiasFeatureDimension($_builder, $0, $1)">; - -// $input needs to be a ranked tensor to identify index of the feature -// dimension depending on the data_format 'NHWC' or 'NCHW'. 
-def : Pat<(TF_BiasAddOp AnyRankedTensor:$input, $bias, $data_format), - (HLO_AddOp $input, $bias, - (BiasAddFeatureDimension $data_format, $input))>; - //===----------------------------------------------------------------------===// // Binary op patterns. //===----------------------------------------------------------------------===// @@ -96,21 +85,22 @@ class DirectBinaryPat : Pat<(FromOp AnyRankedTensor:$l, AnyRankedTensor:$r), (ToOp $l, $r, (BinBroadcastDimensions $l, $r))>; -foreach fromToBinPair = [[TF_AddOp, HLO_AddOp], - [TF_AddV2Op, HLO_AddOp], - [TF_DivOp, HLO_DivOp], - [TF_LeftShiftOp, HLO_ShiftLeftOp], - [TF_MaximumOp, HLO_MaxOp], - [TF_MinimumOp, HLO_MinOp], - [TF_MulOp, HLO_MulOp], - [TF_PowOp, HLO_PowOp], - [TF_RealDivOp, HLO_DivOp], - [TF_SubOp, HLO_SubOp]] in +foreach fromToBinPair = [[TF_AddOp, HLOClient_BroadcastAddOp], + [TF_AddV2Op, HLOClient_BroadcastAddOp], + [TF_DivOp, HLOClient_BroadcastDivOp], + [TF_LeftShiftOp, HLOClient_BroadcastShiftLeftOp], + [TF_MaximumOp, HLOClient_BroadcastMaxOp], + [TF_MinimumOp, HLOClient_BroadcastMinOp], + [TF_MulOp, HLOClient_BroadcastMulOp], + [TF_PowOp, HLOClient_BroadcastPowOp], + [TF_RealDivOp, HLOClient_BroadcastDivOp], + [TF_SubOp, HLOClient_BroadcastSubOp]] in def : DirectBinaryPat; def LowerRightShiftSigned : Pat<(TF_RightShiftOp AnyRankedTensor:$l, AnyRankedTensor:$r), - (HLO_ShiftRightArithmeticOp $l, $r, (BinBroadcastDimensions $l, $r)), + (HLOClient_BroadcastShiftRightArithmeticOp $l, $r, + (BinBroadcastDimensions $l, $r)), [(SignedIntTensor $r)]>; // TODO(hinsu): Lower unsigned types to HLO_ShiftRightLogical once the HLO op @@ -122,10 +112,11 @@ def : Pat<(TF_ComplexOp $r, $i), (HLO_ComplexOp $r, $i)>; // // return floor(div(x, y)) def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r), - (HLO_FloorOp (HLO_DivOp $l, $r, (BinBroadcastDimensions $l, $r))), + (HLO_FloorOp + (HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r))), [(IEEEFloatTensor $l)]>; -// Performs a substitution of FloorDir for integer tensors, which required +// Performs a substitution of FloorDiv for integer tensors, which required // additional correction for a negative numerator / denominator. Equivalent // pseudocode is shown below: // @@ -144,19 +135,19 @@ def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r), // NOTE: This should be optimized for unsigned integers. // Requires static shaped inputs to create constant splats and computation of // broadcast attributes. 
-def : Pat<(TF_FloorDivOp AnyStaticShapeTensor:$l, AnyStaticShapeTensor:$r), +def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r), (HLO_SelectOp - (HLO_CompareOp - (HLO_CompareOp $l, (HLO_ConstOp (ConstantSplat<"0"> $l)), + (HLOClient_BroadcastCompareOp + (HLOClient_BroadcastCompareOp $l, (HLO_ConstOp (GetScalarOfType<0> $l)), (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT), - (HLO_CompareOp $r, (HLO_ConstOp (ConstantSplat<"0"> $r)), + (HLOClient_BroadcastCompareOp $r, (HLO_ConstOp (GetScalarOfType<0> $r)), (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT), (BinBroadcastDimensions $l, $r), HLO_COMPARISON_DIRECTION_EQ), - (HLO_DivOp $l, $r, (BinBroadcastDimensions $l, $r)), - (HLO_DivOp - (HLO_NegOp:$neg (HLO_AddOp (HLO_AbsOp $l), - (HLO_SubOp (HLO_AbsOp $r), - (HLO_ConstOp (ConstantSplat<"1"> $r)), + (HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r)), + (HLOClient_BroadcastDivOp + (HLO_NegOp:$neg (HLOClient_BroadcastAddOp (HLO_AbsOp $l), + (HLOClient_BroadcastSubOp (HLO_AbsOp $r), + (HLO_ConstOp (GetScalarOfType<1> $r)), (NullDenseIntElementsAttr)), (BinBroadcastDimensions $l, $r))), (HLO_AbsOp:$abs $r), (BinBroadcastDimensions $neg, $abs))), @@ -169,22 +160,22 @@ def : Pat<(TF_FloorDivOp AnyStaticShapeTensor:$l, AnyStaticShapeTensor:$r), // return trunc_mod != 0 && (y < 0 != trunc_mod < 0) ? trunc_mod + y // Requires static shaped inputs to create constant splats and computation of // broadcast attributes. -def : Pat<(TF_FloorModOp AnyStaticShapeTensor:$l, AnyStaticShapeTensor:$r), +def : Pat<(TF_FloorModOp AnyRankedTensor:$l, AnyRankedTensor:$r), (HLO_SelectOp - (HLO_AndOp - (HLO_CompareOp - (HLO_RemOp:$rem $l, $r, (BinBroadcastDimensions $l, $r)), - (HLO_ConstOp:$l_zeros (ConstantSplat<"0"> $l)), + (HLOClient_BroadcastAndOp + (HLOClient_BroadcastCompareOp + (HLOClient_BroadcastRemOp:$rem $l, $r, (BinBroadcastDimensions $l, $r)), + (HLO_ConstOp:$l_zeros (GetScalarOfType<0> $l)), (BinBroadcastDimensions $l, $rem), HLO_COMPARISON_DIRECTION_NE), - (HLO_CompareOp - (HLO_CompareOp:$r_cmp $r, - (HLO_ConstOp:$r_zeros (ConstantSplat<"0"> $r)), + (HLOClient_BroadcastCompareOp + (HLOClient_BroadcastCompareOp:$r_cmp $r, + (HLO_ConstOp:$r_zeros (GetScalarOfType<0> $r)), (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT), - (HLO_CompareOp:$rem_cmp $rem, $r_zeros, + (HLOClient_BroadcastCompareOp:$rem_cmp $rem, $r_zeros, (BinBroadcastDimensions $rem, $r_zeros), HLO_COMPARISON_DIRECTION_LT), (BinBroadcastDimensions $r_cmp, $rem_cmp), HLO_COMPARISON_DIRECTION_NE), (NullDenseIntElementsAttr)), - (HLO_AddOp $r, + (HLOClient_BroadcastAddOp $r, $rem, (BinBroadcastDimensions $r, $rem)), $rem)>; //===----------------------------------------------------------------------===// @@ -196,10 +187,10 @@ class DirectLogicalBinaryPat (ToOp $l, $r, (BinBroadcastDimensions $l, $r)), [(SignedIntTensor $l)]>; -foreach fromToBinPair = [[TF_LogicalAndOp, HLO_AndOp], - [TF_LogicalOrOp, HLO_OrOp], - [TF_BitwiseOrOp, HLO_OrOp], - [TF_BitwiseAndOp, HLO_AndOp]] in +foreach fromToBinPair = [[TF_LogicalAndOp, HLOClient_BroadcastAndOp], + [TF_LogicalOrOp, HLOClient_BroadcastOrOp], + [TF_BitwiseOrOp, HLOClient_BroadcastOrOp], + [TF_BitwiseAndOp, HLOClient_BroadcastAndOp]] in def : DirectLogicalBinaryPat; //===----------------------------------------------------------------------===// @@ -208,7 +199,8 @@ foreach fromToBinPair = [[TF_LogicalAndOp, HLO_AndOp], class DirectComparePat : Pat<(FromOp AnyRankedTensor:$l, AnyRankedTensor:$r), - (HLO_CompareOp $l, $r, (BinBroadcastDimensions 
$l, $r), direction)>; + (HLOClient_BroadcastCompareOp + $l, $r, (BinBroadcastDimensions $l, $r), direction)>; def : DirectComparePat; def : DirectComparePat; @@ -218,7 +210,8 @@ def : DirectComparePat; class EqualityPat : Pat<(FromOp AnyRankedTensor:$l, AnyRankedTensor:$r, TrueBoolAttr:$incompatible_shape_error), - (HLO_CompareOp $l, $r, (BinBroadcastDimensions $l, $r), direction), + (HLOClient_BroadcastCompareOp + $l, $r, (BinBroadcastDimensions $l, $r), direction), [(AreBroadcastCompatible $l, $r)]>; def : EqualityPat; @@ -273,6 +266,13 @@ def : Pat<(TF_CrossReplicaSumOp $input, (TF_ConstOp $group_assignment)), (HLO_CrossReplicaSumOp $input, (CastElementsToI64Elements $group_assignment))>; +//===----------------------------------------------------------------------===// +// All2All op patterns. +//===----------------------------------------------------------------------===// + +def : Pat<(TF_AllToAllOp AnyRankedTensor:$input, (TF_ConstOp $group_assignment), I64Attr:$concat_dimension, $split_dimension, $split_count), + (HLO_AllToAllOp $input, $split_dimension, $concat_dimension, $split_count, (CastElementsToI64Elements $group_assignment))>; + //===----------------------------------------------------------------------===// // FFT op patterns. //===----------------------------------------------------------------------===// @@ -393,39 +393,36 @@ def : Pattern<(TF_MatrixBandPartOp:$op AnyRankedTensor:$input, $num_lower, $num_ (HLO_SelectOp:$num_lower_or_m (HLO_CompareOp $num_lower, (HLO_ConstOp:$zero (ConstantSplat<"0"> $num_lower)), - (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT + HLO_COMPARISON_DIRECTION_LT ), $m_dim, $num_lower ), (HLO_SelectOp:$num_upper_or_n (HLO_CompareOp - $num_upper, $zero, - (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT + $num_upper, $zero, HLO_COMPARISON_DIRECTION_LT ), $n_dim, $num_upper ), (HLO_SelectOp (HLO_AndOp - (HLO_CompareOp + (HLOClient_BroadcastCompareOp (HLO_NegOp (createConvertOp $op, $num_lower_or_m, $input) ), (HLO_SubOp:$offset - (createIotaOp<"1"> $op, $input), (createIotaOp<"0"> $op, $input), - (NullDenseIntElementsAttr) + (createIotaOp<"1"> $op, $input), (createIotaOp<"0"> $op, $input) ), (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE ), - (HLO_CompareOp + (HLOClient_BroadcastCompareOp $offset, (createConvertOp $op, $num_upper_or_n, $input ), (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE - ), - (BinBroadcastDimensions $offset, $input) + ) ), $input, (HLO_ConstOp (ConstantSplat<"0"> $input)) @@ -449,8 +446,9 @@ def : Pat<(TF_ConstOp:$res ElementsAttr:$value), // TODO(hinsu): Lower unsigned and quantized types after supporting // them in GetScalarOfType. def : Pat<(TF_ReluOp AnyRankedTensor:$input), - (HLO_MaxOp (HLO_ConstOp:$zero (GetScalarOfType<0> $input)), $input, - (BinBroadcastDimensions $zero, $input)), + (HLOClient_BroadcastMaxOp + (HLO_ConstOp:$zero (GetScalarOfType<0> $input)), $input, + (BinBroadcastDimensions $zero, $input)), [(TF_SintOrFpTensor $input)]>; // TODO(hinsu): Lower unsigned and quantized types after supporting @@ -472,7 +470,7 @@ def : Pat<(TF_Relu6Op AnyRankedTensor:$input), // to create splat tensor of dynamic shape in HLO. 
def : Pat<(TF_ReluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$features), (HLO_SelectOp - (HLO_CompareOp $features, + (HLOClient_BroadcastCompareOp $features, (HLO_ConstOp (GetScalarOfType<0> $features)), (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_GT), $gradients, (HLO_ConstOp (ConstantSplat<"0"> $gradients)))>; @@ -514,16 +512,14 @@ foreach callOp = [TF_PartitionedCallOp, TF_StatefulPartitionedCallOp] in { } //===----------------------------------------------------------------------===// -// Ternary op patterns. +// Reverse op patterns. //===----------------------------------------------------------------------===// -def BothTypesMatch : Constraint, - "types must be equal">; +// Handles axis conversion for TF reverse. +def ConvertAxisAttr : NativeCodeCall<"ConvertAxisAttr($0, $1, &$_builder)">; -def : Pat<(TF_SelectOp $cond, $t, $e), (HLO_SelectOp $cond, $t, $e), - // TODO(jpienaar): This restriction is to avoid creating a currently - // unsupported HLO select. - [(BothTypesMatch $t, $e)]>; +def : Pat<(TF_ReverseV2Op AnyRankedTensor:$values, (TF_ConstOp $axis)), + (HLO_ReverseOp $values, (ConvertAxisAttr $values, $axis))>; //===----------------------------------------------------------------------===// // Unary op patterns. @@ -543,7 +539,6 @@ foreach Mapping = [ [TF_LogicalNotOp, HLO_NotOp], [TF_NegOp, HLO_NegOp], [TF_RealOp, HLO_RealOp], - [TF_RoundOp, HLO_RoundOp], [TF_RsqrtOp, HLO_RsqrtOp], [TF_SinOp, HLO_SinOp], [TF_SqrtOp, HLO_SqrtOp], @@ -576,7 +571,6 @@ def : Pat<(TF_SignOp $x), (HLO_CompareOp $x, $x, - (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_NE ), (HLO_ConstOp (ConstantSplat<"0"> $x)), @@ -617,10 +611,10 @@ def : Pat<(srcDstOpPair[0]:$old $shape, $seed, $seed2), //===----------------------------------------------------------------------===// // Sigmoid grad op. //===----------------------------------------------------------------------===// + +// TODO(hinsu): Handle unranked inputs by broadcasting constant one to the +// shape of $l instead of having it as a constant. 
def : Pat<(TF_SigmoidGradOp AnyRankedTensor:$l, AnyRankedTensor:$r), (HLO_MulOp - (HLO_MulOp $r, $l, (NullDenseIntElementsAttr)), - (HLO_SubOp (HLO_ConstOp (ConstantSplat<"1"> $l)), $l, - (NullDenseIntElementsAttr)), - (NullDenseIntElementsAttr)), - [(IEEEFloatTensor $l)]>; + (HLO_MulOp $r, $l), + (HLO_SubOp (HLO_ConstOp (ConstantSplat<"1"> $l)), $l))>; diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index 551462572f1..b15974979c9 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -83,31 +83,52 @@ static bool IsOpWhitelisted(Operation* op) { // clang-format off static llvm::SmallDenseSet ops = { TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), - TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), - TypeID::get(), - TypeID::get(), TypeID::get(), - TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), - TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -116,21 +137,43 @@ static bool IsOpWhitelisted(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), - TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), - TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), - TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), - TypeID::get() + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get() }; // clang-format on diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td index c0f6c2c3541..21e39db018b 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td @@ -36,47 +36,36 @@ def IsSameSizePred : CPred< def IsSameSizeConstraint : Constraint; -def : Pat<(HLO_AndOp HLO_PredTensor:$l, HLO_PredTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_AndOp HLO_PredTensor:$l, HLO_PredTensor:$r), (AndOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_AddOp HLO_FpTensor:$l, HLO_FpTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_AddOp HLO_FpTensor:$l, HLO_FpTensor:$r), (AddFOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_SubOp HLO_FpTensor:$l, HLO_FpTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_SubOp HLO_FpTensor:$l, HLO_FpTensor:$r), (SubFOp $l, $r), [(IsSameSizeConstraint $l, 
$r)]>; -def : Pat<(HLO_MulOp HLO_FpTensor:$l, HLO_FpTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_MulOp HLO_FpTensor:$l, HLO_FpTensor:$r), (MulFOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_DivOp HLO_FpTensor:$l, HLO_FpTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_DivOp HLO_FpTensor:$l, HLO_FpTensor:$r), (DivFOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_RemOp HLO_FpTensor:$l, HLO_FpTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_RemOp HLO_FpTensor:$l, HLO_FpTensor:$r), (RemFOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_AddOp HLO_IntTensor:$l, HLO_IntTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_AddOp HLO_IntTensor:$l, HLO_IntTensor:$r), (AddIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_SubOp HLO_IntTensor:$l, HLO_IntTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_SubOp HLO_IntTensor:$l, HLO_IntTensor:$r), (SubIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_MulOp HLO_IntTensor:$l, HLO_IntTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_MulOp HLO_IntTensor:$l, HLO_IntTensor:$r), (MulIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_DivOp HLO_IntTensor:$l, HLO_IntTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_DivOp HLO_IntTensor:$l, HLO_IntTensor:$r), (SignedDivIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_RemOp HLO_IntTensor:$l, HLO_IntTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_RemOp HLO_IntTensor:$l, HLO_IntTensor:$r), (SignedRemIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc index 43c0911a4a6..ddbb672c70a 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc @@ -57,8 +57,9 @@ class LhloFuseLinalg : public PassWrapper { for (auto func_arg : func.getArguments()) { func_args.insert(func_arg); } + MLIRContext* ctx = func.getContext(); OpBuilder b(func); - OperationFolder folder(func.getContext()); + OperationFolder folder(ctx); func.walk([&](linalg::GenericOp generic_op) { SmallVector tile_sizes(tile_sizes_.begin(), tile_sizes_.end()); @@ -68,12 +69,14 @@ class LhloFuseLinalg : public PassWrapper { auto op = cast(generic_op.getOperation()); for (const Value result : op.getOutputBuffers()) { if (!func_args.count(result)) continue; - if (tileGenericOp(op, tile_sizes, &b, &folder)) { + if (tileGenericOp(op, tile_sizes, &b)) { generic_op.erase(); return; } } }); + auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx); + applyPatternsAndFoldGreedily(func, patterns); // Fuse producers of tiled linalg ops. llvm::SmallDenseSet erase_set; @@ -92,19 +95,22 @@ class LhloFuseLinalg : public PassWrapper { *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); } } + + auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx); + applyPatternsAndFoldGreedily(func, patterns); } for (auto* e : erase_set) e->erase(); } private: - bool tileGenericOp(LinalgOp op, ArrayRef tile_sizes, OpBuilder* b, - OperationFolder* folder) { - auto tiled_generic_op = - use_parallel_loops_ - ? 
linalg::tileLinalgOpToParallelLoops(*b, op, tile_sizes, - /*permutation=*/{}, folder) - : linalg::tileLinalgOp(*b, op, tile_sizes, - /*permutation=*/{}, folder); + bool tileGenericOp(LinalgOp op, ArrayRef tile_sizes, OpBuilder* b) { + auto loopType = use_parallel_loops_ + ? linalg::LinalgTilingLoopType::ParallelLoops + : linalg::LinalgTilingLoopType::Loops; + auto tiled_generic_op = linalg::tileLinalgOp(*b, op, + linalg::LinalgTilingOptions() + .setTileSizes(tile_sizes) + .setLoopType(loopType)); return tiled_generic_op.hasValue(); } diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc index 2921a49ba70..f7f5537f882 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc @@ -54,14 +54,14 @@ struct BinaryOpConverter : public OpRewritePattern { induction_vars.push_back(forOp.getInductionVar()); rewriter.setInsertionPointToStart(forOp.getBody()); } - auto l = rewriter.create(loc, lhs, induction_vars); - auto r = rewriter.create(loc, rhs, induction_vars); + auto l = rewriter.create(loc, lhs, induction_vars); + auto r = rewriter.create(loc, rhs, induction_vars); Value opResult = xla_lhlo::XlaOpToStdScalarOp::map( op, element_type, {l, r}, &rewriter); if (opResult == nullptr) { return failure(); } - rewriter.create(loc, opResult, op.out(), induction_vars); + rewriter.create(loc, opResult, op.out(), induction_vars); rewriter.eraseOp(op); return success(); } diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc index e6f3ac02d4f..f0eb3cc1a0f 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc @@ -21,7 +21,7 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project #include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project -#include "mlir/Dialect/LoopOps/LoopOps.h" // from @llvm-project +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project @@ -112,7 +112,7 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern { auto step = rewriter.create( loc, rewriter.getIndexType(), rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); - auto loop = rewriter.create(loc, zero, upper, step); + auto loop = rewriter.create(loc, zero, upper, step); rewriter.setInsertionPointToStart(loop.getBody()); // Compute memrefs for the value to reduce. This makes it easier to just @@ -173,8 +173,7 @@ struct LhloLegalizeToGpu : public PassWrapper { OwningRewritePatternList patterns; ConversionTarget target(getContext()); target.addLegalDialect(); + gpu::GPUDialect, scf::SCFDialect, XlaLhloDialect>(); target.addIllegalOp(); auto func = getFunction(); patterns.insert(func.getContext()); diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_llvm.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_llvm.cc new file mode 100644 index 00000000000..385e0859906 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_llvm.cc @@ -0,0 +1,136 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" +#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" + +namespace mlir { +namespace xla_lhlo { +namespace { + +struct StaticMemRefCastOpConverter + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult matchAndRewrite( + Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto cast_op = cast(op); + + StaticMemRefCastOpOperandAdaptor operands_adaptor(operands); + MemRefDescriptor sourceMemRef(operands_adaptor.operand()); + + MemRefType targetMemRefType = + cast_op.getResult().getType().cast(); + auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType) + .dyn_cast_or_null(); + if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) + return failure(); + // Create descriptor. + auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); + Type llvmTargetElementTy = desc.getElementType(); + // Set allocated ptr. + Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); + allocated = + rewriter.create(loc, llvmTargetElementTy, allocated); + desc.setAllocatedPtr(rewriter, loc, allocated); + // Set aligned ptr. + Value ptr = sourceMemRef.alignedPtr(rewriter, loc); + ptr = rewriter.create(loc, llvmTargetElementTy, ptr); + desc.setAlignedPtr(rewriter, loc, ptr); + + // Fill size and stride descriptors in memref. + auto target_sizes = targetMemRefType.getShape(); + int64_t target_offset; + llvm::SmallVector target_strides; + if (failed((getStridesAndOffset(targetMemRefType, target_strides, + target_offset)))) + return failure(); + + // Copy offset of `targetMemRef`. 
+ desc.setConstantOffset(rewriter, loc, target_offset); + for (int i = 0, e = targetMemRefType.getRank(); i < e; ++i) { + desc.setConstantSize(rewriter, loc, i, target_sizes[i]); + desc.setConstantStride(rewriter, loc, i, target_strides[i]); + } + rewriter.replaceOp(op, {desc}); + return success(); + } +}; + +struct DynamicMemRefCastOpConverter + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult matchAndRewrite( + Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto cast_op = cast(op); + + DynamicMemRefCastOpOperandAdaptor operands_adaptor(operands); + MemRefDescriptor sourceMemRef(operands_adaptor.operand()); + + MemRefType targetMemRefType = + cast_op.getResult().getType().cast(); + auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType) + .dyn_cast_or_null(); + if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) + return failure(); + // Create descriptor. + auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); + Type llvmTargetElementTy = desc.getElementType(); + // Set allocated ptr. + Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); + allocated = + rewriter.create(loc, llvmTargetElementTy, allocated); + desc.setAllocatedPtr(rewriter, loc, allocated); + // Set aligned ptr. + Value ptr = sourceMemRef.alignedPtr(rewriter, loc); + ptr = rewriter.create(loc, llvmTargetElementTy, ptr); + desc.setAlignedPtr(rewriter, loc, ptr); + // Copy offset of `sourceMemRef`. + desc.setOffset(rewriter, loc, sourceMemRef.offset(rewriter, loc)); + + // Fill size and stride descriptors in memref. + if (!cast_op.sizes().empty()) { + auto sizes = operands_adaptor.sizes(); + auto strides = operands_adaptor.strides(); + for (int i = 0, e = targetMemRefType.getRank(); i < e; ++i) { + desc.setSize(rewriter, loc, i, sizes[i]); + desc.setStride(rewriter, loc, i, strides[i]); + } + } + rewriter.replaceOp(op, {desc}); + return success(); + } +}; + +} // namespace + +void PopulateLhloToLLVMConversionPatterns(LLVMTypeConverter *converter, + OwningRewritePatternList *patterns) { + patterns->insert( + *converter); +} + +} // namespace xla_lhlo +} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_llvm_pass.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_llvm_pass.cc new file mode 100644 index 00000000000..9b809049290 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_llvm_pass.cc @@ -0,0 +1,58 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" +#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" + +namespace mlir { +namespace xla_lhlo { +namespace { + +class TestLhloToLLVMPass + : public ::mlir::PassWrapper> { + public: + void runOnOperation() override { + ModuleOp m = getOperation(); + + OwningRewritePatternList patterns; + LLVMTypeConverter converter(m.getContext()); + populateStdToLLVMConversionPatterns(converter, patterns); + PopulateLhloToLLVMConversionPatterns(&converter, &patterns); + + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addLegalOp(); + target.addIllegalDialect(); + + if (failed(applyFullConversion(m, target, patterns, &converter))) { + signalPassFailure(); + } + } +}; + +} // namespace + +static PassRegistration legalize_lhlo_pass( + "test-lhlo-legalize-to-llvm", "Legalize from LHLO dialect to LLVM."); + +} // namespace xla_lhlo +} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc index 54b3acd3787..734a75a4307 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc @@ -18,7 +18,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project -#include "mlir/Dialect/LoopOps/LoopOps.h" // from @llvm-project +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project @@ -61,15 +61,15 @@ Value ApplySingleResultLhloCode(Location loc, ValueRange operands, // Converts a block with LHLO ops and with signature: // ^bb(%lhs: memref, %rhs: memref, %res: memref): -// into a reduction operator of loop.reduce by doing buffer allocation for -// scalar arguments and the result of `loop.reduce` to make it compatible with +// into a reduction operator of scf.reduce by doing buffer allocation for +// scalar arguments and the result of `scf.reduce` to make it compatible with // LHLO ops. -void ConvertToReductionOperator(Location loc, loop::ReduceOp reduce_op, +void ConvertToReductionOperator(Location loc, scf::ReduceOp reduce_op, Block* lhlo_block, OpBuilder* b) { Block& loop_reduce_op_body = reduce_op.reductionOperator().front(); OpBuilder::InsertionGuard guard(*b); b->setInsertionPointToStart(&loop_reduce_op_body); - b->create( + b->create( loc, ApplySingleResultLhloCode(loc, loop_reduce_op_body.getArguments(), lhlo_block, b)); } @@ -136,9 +136,9 @@ MappedIvs MapWindowIvsToInput(OpTy op, ValueRange ivs, ValueRange window_ivs, return mapped_ivs; } -// Returns loop::Parallel over a shaped value with static or dynamic shape. 
-loop::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value, - OpBuilder* b) { +// Returns scf::Parallel over a shaped value with static or dynamic shape. +scf::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value, + OpBuilder* b) { Value zero = b->create(loc, 0); Value one = b->create(loc, 1); @@ -151,10 +151,10 @@ loop::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value, lower.push_back(zero); step.push_back(one); } - return b->create(loc, lower, upper, step); + return b->create(loc, lower, upper, step); } -// Converts `xla_lhlo.ReduceOp` into two loop::ParallelOp and a loop::ReduceOp. +// Converts `xla_lhlo.ReduceOp` into two scf::ParallelOp and a scf::ReduceOp. // The outper `ParallelOp` refers to the parallel loops if there are // any. The inner `ParalleOp` refers to the reduction loops and `ReduceOp` // contains the reduction operator. @@ -170,10 +170,10 @@ loop::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value, // is roughly converted into: // // %init = load %init_buf[] : memref -// loop.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) { -// %result = loop.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) { +// scf.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) { +// %result = scf.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) { // %elem_to_reduce = load %buffer[%i, %j, %k] : memref<100x10x5xf32> -// loop.reduce(%elem_to_reduce) { +// scf.reduce(%elem_to_reduce) { // ^bb0(%elem: f32, %acc: f32): // no predecessors // elem_buf = alloc() : memref // store %elem, elem_buf[] : memref @@ -181,11 +181,11 @@ loop::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value, // store %acc, acc_buf[] : memref // // %acc_result = load acc_buf[] : memref -// loop.reduce.return %acc_result : f32 +// scf.reduce.return %acc_result : f32 // } : f32 -// loop.yield +// scf.yield // } : f32 -// loop.yield +// scf.yield // } class ReduceOpConverter : public OpConversionPattern { public: @@ -197,7 +197,7 @@ class ReduceOpConverter : public OpConversionPattern { // TODO(b/137624192) Implement variadic reduce. if (xla_reduce_op.out().size() != 1) return failure(); - loop::ReduceOp reduce_op = + scf::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops(xla_reduce_op, &rewriter); ConvertToReductionOperator(xla_reduce_op.getLoc(), reduce_op, &xla_reduce_op.body().front(), &rewriter); @@ -206,26 +206,26 @@ class ReduceOpConverter : public OpConversionPattern { } private: - // Creates nested `loop.parallel` ops with `loop.reduce`. The outer ParallelOp + // Creates nested `scf.parallel` ops with `scf.reduce`. The outer ParallelOp // refers to the parallel dimensions of `xla_reduce_op` if any and the inner - // ParallelOp refers to the reduction dimensions. The loop.reduce op is + // ParallelOp refers to the reduction dimensions. The scf.reduce op is // returned. 
// // If the reduction argument is a memref<100x10x5xf32> and the // reduction is performed along dimension 1 then this method will generate // // %init = load %init_buf[] : memref - // loop.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) { - // %result = loop.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) { + // scf.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) { + // %result = scf.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) { // %elem_to_reduce = load %buffer[%i, %j, %k] : memref<100x10x5xf32> - // loop.reduce(%elem_to_reduce) { + // scf.reduce(%elem_to_reduce) { // // } : f32 - // loop.yield + // scf.yield // } : f32 - // loop.yield + // scf.yield // } - loop::ReduceOp CreateReduceOpInNestedParallelLoops( + scf::ReduceOp CreateReduceOpInNestedParallelLoops( xla_lhlo::ReduceOp xla_reduce_op, ConversionPatternRewriter* rewriter) const { auto loc = xla_reduce_op.getLoc(); @@ -254,13 +254,13 @@ class ReduceOpConverter : public OpConversionPattern { SmallVector init_value = { rewriter->create(loc, *xla_reduce_op.init_values().begin())}; // Outer ParallelOp is not needed if it is a reduction across all dims. - loop::ParallelOp outer; + scf::ParallelOp outer; if (!parallel_lower.empty()) { - outer = rewriter->create(loc, parallel_lower, - parallel_upper, parallel_step); + outer = rewriter->create(loc, parallel_lower, + parallel_upper, parallel_step); rewriter->setInsertionPointToStart(outer.getBody()); } - loop::ParallelOp inner = rewriter->create( + scf::ParallelOp inner = rewriter->create( loc, reduce_lower, reduce_upper, reduce_step, init_value); Value reduction_result = *inner.getResults().begin(); @@ -294,7 +294,7 @@ class ReduceOpConverter : public OpConversionPattern { rewriter->setInsertionPointToStart(inner.getBody()); Value elem = rewriter->create( loc, *xla_reduce_op.operands().begin(), indices); - return rewriter->create(loc, elem); + return rewriter->create(loc, elem); } }; @@ -314,8 +314,8 @@ class ReduceOpConverter : public OpConversionPattern { // accumulator = reduction_operator(output[O], value) // output[O] = accumulator // -// Converts `xla_lhlo.ReduceWindowOp` into two loop::ParallelOp and a -// loop::ReduceOp. +// Converts `xla_lhlo.ReduceWindowOp` into two scf::ParallelOp and a +// scf::ReduceOp. // The outper `ParallelOp` refers to the parallel loops that traverese output // buffer. The inner `ParalleOp` refers to the reduction loops that traverse // reduction windows and `ReduceOp` contains the reduction operator. 
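[Editor's note: illustrative sketch, not part of the patch.] The ReduceOp lowering documented above emits an outer scf.parallel over the non-reduced dimensions and an inner scf.parallel plus scf.reduce over the reduced dimension (the pseudocode uses a memref<100x10x5xf32> reduced along dimension 1). A scalar C++ model of what that loop nest computes, assuming addition as the reduction operator and a small illustrative 4x3x2 shape, is:

#include <array>
#include <iostream>

int main() {
  constexpr int I = 4, J = 3, K = 2;          // reduce along the middle dim J
  std::array<float, I * J * K> buffer{};
  for (int n = 0; n < I * J * K; ++n) buffer[n] = static_cast<float>(n);

  const float init = 0.0f;                    // plays the role of %init_buf
  std::array<float, I * K> out{};             // reduced dim removed from result

  for (int i = 0; i < I; ++i) {               // outer "parallel" loops (i, k)
    for (int k = 0; k < K; ++k) {
      float acc = init;                       // init value of the inner loop
      for (int j = 0; j < J; ++j)             // inner loop over the reduced dim
        acc += buffer[(i * J + j) * K + k];   // reduction operator body
      out[i * K + k] = acc;
    }
  }
  std::cout << "out[0] = " << out[0] << "\n";
}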
@@ -341,20 +341,20 @@ class ReduceOpConverter : public OpConversionPattern { // is roughly converted into: // // %neutral_elem = load %init_buf[] : memref -// loop.parallel (%i, %j) = (%c0, %c0) to (%c56, %c56) step (%c1, %c1) { -// %result = loop.parallel (%iw, %jw) = (%c0, %c0) +// scf.parallel (%i, %j) = (%c0, %c0) to (%c56, %c56) step (%c1, %c1) { +// %result = scf.parallel (%iw, %jw) = (%c0, %c0) // to (%c3, %c3) step (%c1, %c1) neutral_elem (%0) -> f32 { // %in_bounds = // %elem = load %operand[%computed_i, %computed_j] // %elem_or_neutral = select %in_bounds, %elem, %neutral_elem : f32 -// loop.reduce(%elem_to_reduce) : f32 { +// scf.reduce(%elem_to_reduce) : f32 { // ^bb0(%arg7: f32, %arg8: f32): // // } -// loop.yield +// scf.yield // } // store %result, %output_buffer[%i, %j] : memref<56x56xf32> -// loop.yield +// scf.yield // } // return // } @@ -366,12 +366,12 @@ class ReduceWindowOpConverter LogicalResult matchAndRewrite( xla_lhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef /*args*/, ConversionPatternRewriter& rewriter) const final { - loop::ParallelOp output_loop, window_loop; + scf::ParallelOp output_loop, window_loop; std::tie(output_loop, window_loop) = CreateParallelLoopsToTraverseOutputAndWindow(xla_reduce_window_op, &rewriter); - loop::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops( + scf::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops( xla_reduce_window_op, output_loop, window_loop, &rewriter); ConvertToReductionOperator(xla_reduce_window_op.getLoc(), reduce_op, @@ -381,7 +381,7 @@ class ReduceWindowOpConverter } private: - std::pair + std::pair CreateParallelLoopsToTraverseOutputAndWindow( xla_lhlo::ReduceWindowOp xla_reduce_window_op, ConversionPatternRewriter* rewriter) const { @@ -405,7 +405,7 @@ class ReduceWindowOpConverter window_upper.push_back( rewriter->create(loc, window_dim.getSExtValue())); } - auto window_loop = rewriter->create( + auto window_loop = rewriter->create( loc, window_lower, window_upper, window_step, init_value); Value reduction_result = *window_loop.getResults().begin(); @@ -414,9 +414,9 @@ class ReduceWindowOpConverter return std::make_pair(output_loop, window_loop); } - loop::ReduceOp CreateReduceOpInNestedParallelLoops( + scf::ReduceOp CreateReduceOpInNestedParallelLoops( xla_lhlo::ReduceWindowOp xla_reduce_window_op, - loop::ParallelOp output_loop, loop::ParallelOp window_loop, + scf::ParallelOp output_loop, scf::ParallelOp window_loop, ConversionPatternRewriter* rewriter) const { rewriter->setInsertionPointToStart(window_loop.getBody()); auto loc = xla_reduce_window_op.getLoc(); @@ -436,20 +436,20 @@ class ReduceWindowOpConverter xla_reduce_window_op, output_loop.getInductionVars(), window_loop.getInductionVars(), rewriter); - auto elem_or_init = rewriter->create( + auto elem_or_init = rewriter->create( loc, xla_operand_type.getElementType(), mapped_ivs.in_bounds, /*withElseRegion=*/true); OpBuilder then_builder = elem_or_init.getThenBodyBuilder(); Value elem = then_builder.create( loc, xla_reduce_window_op.operand(), mapped_ivs.ivs); - then_builder.create(loc, elem); + then_builder.create(loc, elem); OpBuilder else_builder = elem_or_init.getElseBodyBuilder(); - else_builder.create(loc, *window_loop.initVals().begin()); + else_builder.create(loc, *window_loop.initVals().begin()); - return rewriter->create(loc, - *elem_or_init.results().begin()); + return rewriter->create(loc, + *elem_or_init.results().begin()); } }; @@ -457,16 +457,16 @@ class ReduceWindowOpConverter // 
https://www.tensorflow.org/xla/operation_semantics#selectandscatter // // Pseudocode: -// loop.parallel(coordinates O in the output): +// scf.parallel(coordinates O in the output): // output[O] = init -// loop.parallel(coordinates S in the source): +// scf.parallel(coordinates S in the source): // selected_ivs = 0 // selected_val = 0 // initialized_flag = false -// loop.for (first dim W_1 in the window) +// scf.for (first dim W_1 in the window) // iter_args (selected_ivs, selected_val, initialized_flag): // ... -// loop.for (last dim W_N in the window): +// scf.for (last dim W_N in the window): // iter_args (selected_ivs, selected_val, initialized_flag): // I = S * stride + W - pad_low // if I within bounds of operand: @@ -490,7 +490,7 @@ class SelectAndScatterOpConverter ConversionPatternRewriter& rewriter) const final { auto loc = s_and_s_op.getLoc(); InitializeOutput(s_and_s_op, &rewriter); - loop::ParallelOp loop_over_src = + scf::ParallelOp loop_over_src = MakeLoopOverShape(loc, s_and_s_op.source(), &rewriter); rewriter.setInsertionPointToStart(loop_over_src.getBody()); @@ -520,7 +520,7 @@ class SelectAndScatterOpConverter auto loc = s_and_s_op.getLoc(); Value init_value = b->create(loc, s_and_s_op.init_value()); - loop::ParallelOp loop_over_output = + scf::ParallelOp loop_over_output = MakeLoopOverShape(loc, s_and_s_op.out(), b); OpBuilder::InsertionGuard guard(*b); b->setInsertionPointToStart(loop_over_output.getBody()); @@ -531,10 +531,10 @@ class SelectAndScatterOpConverter struct WindowLoops { SmallVector selected_ivs; SmallVector window_ivs; - loop::ForOp inner_loop; + scf::ForOp inner_loop; }; WindowLoops InsertWindowLoops(xla_lhlo::SelectAndScatterOp s_and_s_op, - loop::ParallelOp loop_over_src, + scf::ParallelOp loop_over_src, OpBuilder* b) const { auto loc = s_and_s_op.getLoc(); Value zero = b->create(loc, 0); @@ -558,12 +558,12 @@ class SelectAndScatterOpConverter s_and_s_op.window_dimensions()->getIntValues()) { Value upper = b->create(loc, window_dim.getSExtValue()); result.inner_loop = - b->create(loc, zero, upper, one, iter_args); + b->create(loc, zero, upper, one, iter_args); if (b->getInsertionBlock() == loop_over_src.getBody()) { ip = b->saveInsertionPoint(); result.selected_ivs = result.inner_loop.getResults().take_front(rank); } else { - b->create(loc, result.inner_loop.getResults()); + b->create(loc, result.inner_loop.getResults()); } b->setInsertionPointToStart(result.inner_loop.getBody()); iter_args = ValueRange{result.inner_loop.getRegionIterArgs()}; @@ -599,7 +599,7 @@ class SelectAndScatterOpConverter }; SmallVector SelectIvs(xla_lhlo::SelectAndScatterOp s_and_s_op, - loop::ParallelOp loop_over_src, + scf::ParallelOp loop_over_src, OpBuilder* b) const { auto loc = s_and_s_op.getLoc(); @@ -614,7 +614,7 @@ class SelectAndScatterOpConverter IterArgs ivs_val_flag(window_loops.inner_loop.getRegionIterArgs()); - auto if_in_bounds = inner_loop_b.create( + auto if_in_bounds = inner_loop_b.create( loc, window_loops.inner_loop.getResultTypes(), mapped_ivs.in_bounds, /*withElseRegion=*/true); @@ -623,16 +623,16 @@ class SelectAndScatterOpConverter OpBuilder in_bounds_then_b = if_in_bounds.getThenBodyBuilder(); auto select_or_init_results = SelectOrInitialize( s_and_s_op, mapped_ivs.ivs, &ivs_val_flag, &in_bounds_then_b); - in_bounds_then_b.create(loc, select_or_init_results); + in_bounds_then_b.create(loc, select_or_init_results); } // Case when we are in the pad. 
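A rough standalone C++ sketch of the per-window selection logic described by the pseudocode above, reduced to one dimension (the in-pad branch of the conversion continues right after this sketch). The stride, low padding, and a greater-than select function are assumptions for illustration; the real pass reads them from the op's attributes and its select region.

    #include <cstdint>
    #include <vector>

    // For one source coordinate S, visit every window offset W, map it to an
    // operand index I = S * stride + W - pad_low, skip out-of-bounds (padding)
    // positions, and track the selected element plus an initialized flag, much
    // like the iter_args threaded through the scf.for loops.
    int64_t SelectInWindow(const std::vector<float>& operand, int64_t S,
                           int64_t window_size, int64_t stride, int64_t pad_low) {
      int64_t selected_iv = -1;
      float selected_val = 0.0f;
      bool initialized = false;
      for (int64_t W = 0; W < window_size; ++W) {
        int64_t I = S * stride + W - pad_low;
        if (I < 0 || I >= static_cast<int64_t>(operand.size())) continue;  // in the pad
        if (!initialized || operand[I] > selected_val) {  // select function (assumed: >)
          selected_iv = I;
          selected_val = operand[I];
          initialized = true;
        }
      }
      return selected_iv;  // -1 means the whole window fell into the padding
    }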
{ OpBuilder in_bounds_else_b = if_in_bounds.getElseBodyBuilder(); - in_bounds_else_b.create(loc, ivs_val_flag.to_vector()); + in_bounds_else_b.create(loc, ivs_val_flag.to_vector()); } - inner_loop_b.create(loc, if_in_bounds.getResults()); + inner_loop_b.create(loc, if_in_bounds.getResults()); return window_loops.selected_ivs; } @@ -647,8 +647,8 @@ class SelectAndScatterOpConverter Value operand_elem = b->create(loc, s_and_s_op.operand(), operand_ivs); auto if_init = - b->create(loc, iter_arg_types, ivs_val_flag->is_init(), - /*withElseRegion=*/true); + b->create(loc, iter_arg_types, ivs_val_flag->is_init(), + /*withElseRegion=*/true); // Init == true, i.e. iter args are already initialized with a selected // element in boundaries of the operand. Select function has to be computed // here. @@ -660,32 +660,31 @@ class SelectAndScatterOpConverter ApplySingleResultLhloCode(loc, {operand_elem, ivs_val_flag->value()}, &lhlo_select, &if_init_then_b); - auto if_pred = - if_init_then_b.create(loc, iter_arg_types, pred, - /*withElseRegion=*/true); + auto if_pred = if_init_then_b.create(loc, iter_arg_types, pred, + /*withElseRegion=*/true); // Pred == true, therefore pack newly selected ivs, val and init flag back // to iter_args and return. { OpBuilder if_pred_then_b = if_pred.getThenBodyBuilder(); - if_pred_then_b.create( + if_pred_then_b.create( loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector()); } // Pred == false, therefore return old iter_args. { OpBuilder if_pred_else_b = if_pred.getElseBodyBuilder(); - if_pred_else_b.create(loc, ivs_val_flag->to_vector()); + if_pred_else_b.create(loc, ivs_val_flag->to_vector()); } - if_init_then_b.create(loc, if_pred.getResults()); + if_init_then_b.create(loc, if_pred.getResults()); } // Init == false, i.e. only pad was visited before and this is the first // element in the boundaries of the operand. { OpBuilder if_init_else_b = if_init.getElseBodyBuilder(); - if_init_else_b.create( + if_init_else_b.create( loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector()); } return if_init.getResults(); @@ -708,7 +707,7 @@ struct LhloLegalizeToParallelLoops ConversionTarget target(getContext()); target.addLegalDialect(); + scf::SCFDialect, XlaLhloDialect>(); target.addIllegalOp(); diff --git a/tensorflow/compiler/mlir/xla/transforms/lower_complex_patterns.td b/tensorflow/compiler/mlir/xla/transforms/lower_complex_patterns.td index dcb0ab20e9e..e1ae5ef6abf 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lower_complex_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/lower_complex_patterns.td @@ -28,70 +28,62 @@ include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td" // and imaginary components. 
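The patterns below decompose complex multiplication, division, absolute value, and exponential into operations on the real and imaginary components. As a quick sanity check of those identities (standalone C++, not part of the TableGen patterns themselves), compared against std::complex:

    #include <cassert>
    #include <cmath>
    #include <complex>

    int main() {
      std::complex<double> lhs(1.5, -2.0), rhs(0.25, 3.0);
      double a = lhs.real(), b = lhs.imag(), c = rhs.real(), d = rhs.imag();

      // Multiplication: result.real = ac - bd, result.imag = ad + bc.
      std::complex<double> mul(a * c - b * d, a * d + b * c);
      assert(std::abs(mul - lhs * rhs) < 1e-12);

      // Division: numerator = lhs * conj(rhs); denominator = rhs * conj(rhs),
      // which is purely real, so the division distributes over the components.
      std::complex<double> num = lhs * std::conj(rhs);
      double den = (rhs * std::conj(rhs)).real();
      std::complex<double> div(num.real() / den, num.imag() / den);
      assert(std::abs(div - lhs / rhs) < 1e-12);

      // Absolute value: sqrt(real * real + imag * imag).
      assert(std::abs(std::sqrt(a * a + b * b) - std::abs(lhs)) < 1e-12);

      // Exponential: exp(real) * (cos(imag) + i * sin(imag)).
      std::complex<double> ex(std::exp(a) * std::cos(b), std::exp(a) * std::sin(b));
      assert(std::abs(ex - std::exp(lhs)) < 1e-12);
      return 0;
    }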
foreach elementwiseOp = [HLO_AddOp, HLO_SubOp] in def : Pat<(elementwiseOp HLO_ComplexTensor:$lhs, - HLO_ComplexTensor:$rhs, $broadcast_dimensions), + HLO_ComplexTensor:$rhs), (HLO_ComplexOp - (elementwiseOp (HLO_RealOp $lhs), (HLO_RealOp $rhs), - $broadcast_dimensions), - (elementwiseOp (HLO_ImagOp $lhs), (HLO_ImagOp $rhs), - $broadcast_dimensions))>; + (elementwiseOp (HLO_RealOp $lhs), (HLO_RealOp $rhs)), + (elementwiseOp (HLO_ImagOp $lhs), (HLO_ImagOp $rhs)))>; // Complex multiplication results in a cross product multiplication between the // real and imaginary components such that: // result.real = lhs.real * rhs.real - lhs.imag * rhs.imag // result.imag = lhs.imag * rhs.real + lhs.real * rhs.imag def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs, - HLO_ComplexTensor:$rhs, $broadcast_dimensions), + HLO_ComplexTensor:$rhs), (HLO_ComplexOp (HLO_SubOp (HLO_MulOp (HLO_RealOp:$lhs_real $lhs), - (HLO_RealOp:$rhs_real $rhs), - $broadcast_dimensions), + (HLO_RealOp:$rhs_real $rhs)), (HLO_MulOp (HLO_ImagOp:$lhs_imag $lhs), - (HLO_ImagOp:$rhs_imag $rhs), - $broadcast_dimensions), - (NullDenseIntElementsAttr)), + (HLO_ImagOp:$rhs_imag $rhs))), (HLO_AddOp - (HLO_MulOp $lhs_real, $rhs_imag, $broadcast_dimensions), - (HLO_MulOp $lhs_imag, $rhs_real, $broadcast_dimensions), - (NullDenseIntElementsAttr)))>; + (HLO_MulOp $lhs_real, $rhs_imag), + (HLO_MulOp $lhs_imag, $rhs_real)))>; // Multiplication between a complex and real tensor can be distributed by // applying the real multiplicant to both the real and complex component. // // Note that the sourcep pattern is not legal according to the HLO dialect but // instead handle intermediates generated by other patterns. -def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs, $broadcast_dimensions), +def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs), (HLO_ComplexOp - (HLO_MulOp (HLO_RealOp $lhs), $rhs, $broadcast_dimensions), - (HLO_MulOp (HLO_ImagOp $lhs), $rhs, $broadcast_dimensions))>; + (HLO_MulOp (HLO_RealOp $lhs), $rhs), + (HLO_MulOp (HLO_ImagOp $lhs), $rhs))>; -def : Pat<(HLO_MulOp HLO_IntOrFpTensor:$lhs, HLO_ComplexTensor:$rhs, $broadcast_dimensions), +def : Pat<(HLO_MulOp HLO_IntOrFpTensor:$lhs, HLO_ComplexTensor:$rhs), (HLO_ComplexOp - (HLO_MulOp $lhs, (HLO_RealOp $rhs), $broadcast_dimensions), - (HLO_MulOp $lhs, (HLO_ImagOp $rhs), $broadcast_dimensions))>; + (HLO_MulOp $lhs, (HLO_RealOp $rhs)), + (HLO_MulOp $lhs, (HLO_ImagOp $rhs)))>; // Division is performed by normalizing the denominator by multiplying by the // conjugate of the rhs. 
// numerator = lhs * conj(rhs) // denominator = rhs * conj(rhs) -def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_ComplexTensor:$rhs, $broadcast_dimensions), +def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_ComplexTensor:$rhs), (HLO_DivOp (HLO_MulOp:$num $lhs, (HLO_ComplexOp:$conj (HLO_RealOp $rhs), - (HLO_NegOp (HLO_ImagOp $rhs))), - $broadcast_dimensions), - (HLO_RealOp:$den (HLO_MulOp $rhs, $conj, $broadcast_dimensions)), - (BinBroadcastDimensions $num, $den))>; + (HLO_NegOp (HLO_ImagOp $rhs)))), + (HLO_RealOp:$den (HLO_MulOp $rhs, $conj)))>; -def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs, $broadcast_dimensions), +def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs), (HLO_ComplexOp - (HLO_DivOp (HLO_RealOp $lhs), $rhs, $broadcast_dimensions), - (HLO_DivOp (HLO_ImagOp $lhs), $rhs, $broadcast_dimensions))>; + (HLO_DivOp (HLO_RealOp $lhs), $rhs), + (HLO_DivOp (HLO_ImagOp $lhs), $rhs))>; // Absolute value is evaluated as: @@ -100,11 +92,8 @@ def : Pat<(HLO_AbsOp HLO_ComplexTensor:$val), (HLO_ComplexOp (HLO_SqrtOp (HLO_AddOp - (HLO_MulOp (HLO_RealOp:$real $val), $real, - (NullDenseIntElementsAttr)), - (HLO_MulOp (HLO_ImagOp:$imag $val), $imag, - (NullDenseIntElementsAttr)), - (NullDenseIntElementsAttr))), + (HLO_MulOp (HLO_RealOp:$real $val), $real), + (HLO_MulOp (HLO_ImagOp:$imag $val), $imag))), (HLO_ConstOp (ConstantSplat<"0"> $real)))>; // Exponential can be lowered to an exponential on the real component and a @@ -117,5 +106,4 @@ def : Pat<(HLO_ExpOp HLO_ComplexTensor:$val), (HLO_ExpOp (HLO_RealOp $val)), (HLO_ComplexOp (HLO_CosOp (HLO_ImagOp:$imag $val)), - (HLO_SinOp $imag)), - (NullDenseIntElementsAttr))>; + (HLO_SinOp $imag)))>; diff --git a/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h b/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h index fed21e9bafc..21b954a3eb4 100644 --- a/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h +++ b/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h @@ -49,6 +49,7 @@ MAP_HLO_TO_LHLO(ConvertOp); MAP_HLO_TO_LHLO(CopyOp); MAP_HLO_TO_LHLO(CosOp); MAP_HLO_TO_LHLO(DivOp); +MAP_HLO_TO_LHLO(DotOp); MAP_HLO_TO_LHLO(ExpOp); MAP_HLO_TO_LHLO(ImagOp); MAP_HLO_TO_LHLO(IotaOp); diff --git a/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h b/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h index dceb73efb33..c317dc36b3c 100644 --- a/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h +++ b/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h @@ -227,6 +227,28 @@ inline Value MapLhloOpToStdScalarOp( loc, result_types, args, b); } +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, + b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, b); +} + template <> inline Value MapLhloOpToStdScalarOp( Location loc, ArrayRef result_types, ArrayRef args, @@ -259,11 +281,9 @@ inline Value MapLhloOpToStdScalarOp( // No conversion is needed for the same width integers return args.front(); } - // TODO(dfki-ehna): Add other primitive type conversions - // if 
(mlir::FpToSiOp::areCastCompatible(sourceType, targetType)) { - // return b.create(loc, result_types, - // args,mlir::None); - // } + if (mlir::FPToSIOp::areCastCompatible(sourceType, targetType)) { + return b->create(loc, result_types, args, mlir::None); + } return nullptr; } diff --git a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc index a4ffa57957e..c56f5adc12d 100644 --- a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc +++ b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc @@ -28,272 +28,6 @@ namespace xla_hlo { namespace { -// Returns a 1-d i64 elements attribute populated with numbers from start to -// end, excluding. -static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end, - Builder *builder) { - int size = end - start; - - SmallVector vals; - vals.resize(size); - std::iota(vals.begin(), vals.end(), start); - - TensorType ty = RankedTensorType::get({size}, builder->getIntegerType(64)); - return DenseIntElementsAttr::get(ty, vals); -} - -// Helper function for OpRewritePattern classes to materialize broadcasts on -// LHS and RHS arguments to a binary op. -// -// Returns true and sets out_lhs and out_rhs to BroadcastInDimOps if successful, -// returns false otherwise. -template -bool CreateStaticBroadcastsForBinaryOp(SrcOp op, PatternRewriter *rewriter, - Value *out_lhs, Value *out_rhs) { - if (!op.broadcast_dimensions().hasValue()) { - // Note: the op may still have an implicit broadcast on it, such as - // for (tensor<1xf32>, tensor<4xf32>). - return false; - } - - // Insert BroadcastInDimOps for the left-hand-side and right-hand-side args, - // replacing the original LHS and RHS args in the source op with the results - // of the broadcasts. - // - // If the higher dimensional argument does not actually need the broadcast, - // a canonicalization pass should be able to remove that op later. - Value lhs = op.lhs(); - Value rhs = op.rhs(); - - auto op_ranked_type = op.getType().template dyn_cast(); - auto lhs_ranked_type = lhs.getType().dyn_cast(); - auto rhs_ranked_type = rhs.getType().dyn_cast(); - if (!op_ranked_type || !lhs_ranked_type || !rhs_ranked_type) { - // Unranked, can't determine at this point how to perform the broadcast. - return false; - } - - // Dynamic result shape, can't use BroadcastInDimOp. - assert(op_ranked_type.hasStaticShape() && - "dynamic shape requires DynamicBroadcastInDim"); - - auto lhs_rank = lhs_ranked_type.getRank(); - auto rhs_rank = rhs_ranked_type.getRank(); - - // Set broadcast_dimensions to [0, ..., rank] for the higher rank arg. - // Use the original op.broadcast_dimensions for the lower rank arg. - auto higher_rank_broadcast_dims = - GetI64ElementsAttrForSeq(0, std::max(lhs_rank, rhs_rank), rewriter); - DenseIntElementsAttr lhs_broadcast_dims; - DenseIntElementsAttr rhs_broadcast_dims; - if (lhs_rank > rhs_rank) { - lhs_broadcast_dims = higher_rank_broadcast_dims; - rhs_broadcast_dims = op.broadcast_dimensions().getValue(); - } else if (lhs_rank < rhs_rank) { - lhs_broadcast_dims = op.broadcast_dimensions().getValue(); - rhs_broadcast_dims = higher_rank_broadcast_dims; - } else { - // This shouldn't happen for legal ops. If the broadcast_dimensions - // attribute is set, the ranks should be different. - // TODO(scotttodd): Add a custom verification for ops and assert here. 
- return false; - } - - // BroadcastInDimOp must have the same element type for operands and results, - // so preserve the original output shape and the original input element type. - // For example, `SrcOp (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xi1>`: - // broadcast_in_dim (tensor<1x4xf32>) -> tensor<1x4xf32> - // broadcast_in_dim (tensor<4xf32>) -> tensor<1x4xf32> - // SrcOp (tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<1x4xi1> - ArrayRef op_shape = op_ranked_type.getShape(); - auto lhs_type = - RankedTensorType::get(op_shape, lhs_ranked_type.getElementType()); - auto rhs_type = - RankedTensorType::get(op_shape, rhs_ranked_type.getElementType()); - - *out_lhs = rewriter->createOrFold(op.getLoc(), lhs_type, - lhs, lhs_broadcast_dims); - *out_rhs = rewriter->createOrFold(op.getLoc(), rhs_type, - rhs, rhs_broadcast_dims); - return true; -} - -// Helper template to generate code for computing the result shape of a -// broadcasted operation. This ultimately should be subsumed by functions -// from the shape dialect. -// Assumes that large and small are the operand values of `op` and that they -// have a ranked tensory type with rank(large) >= rank(small). -template -std::vector ComputeBroadcastedShape(SrcOp op, Value small, Value large, - PatternRewriter *rewriter) { - auto loc = op.getLoc(); - auto larger_ranked_type = large.getType().cast(); - auto output_rank = larger_ranked_type.getRank(); - - constexpr int kExpandShape = -1; - - std::vector shape_values; - shape_values.reserve(output_rank); - std::vector indexes(output_rank, kExpandShape); - DenseIntElementsAttr broadcast_dimensions = - op.broadcast_dimensions().getValue(); - // Compute a mapping from output dimensions to their corresponding input - // dimensions in the smaller ranked operand. - for (auto pair : llvm::enumerate(broadcast_dimensions.getIntValues())) { - indexes.at(pair.value().getLimitedValue()) = pair.index(); - } - - // Compute the broadcasted shape of the result using numpy style broadcasting - // semantics. The result shape at a position is the shape of the larger - // operand at that position if the no dimension of the smaller operand is - // mapped to it. - // If both operands contribute to an output dimension, their shape has to - // either be the same in that dimension or it can be 1, in which case the - // shape of the other operand is used. - for (int i = 0; i < output_rank; ++i) { - Value index_value; - if (indexes[i] == kExpandShape) { - // The smaller shape gets expanded to the larger one in this case. - index_value = rewriter->create(loc, large, i); - } else { - // Compute the result shape depending on whether the rank of smaller is 1. - // This does not check that the broadcast operation actualy is correct. - // In particular, we do not check that both shapes are the same if the - // smaller ranked shape is not 1. - ConstantOp one = rewriter->create( - loc, rewriter->getIntegerAttr(rewriter->getIndexType(), 1)); - DimOp lrg_dim = rewriter->create(loc, large, i); - DimOp sml_dim = rewriter->create(loc, small, indexes[i]); - CmpIOp compare = - rewriter->create(loc, CmpIPredicate::eq, lrg_dim, one); - index_value = - rewriter->create(loc, compare, lrg_dim, sml_dim); - } - // Ideally, we would like to keep this on index but MLIR does not allow - // this. 
- shape_values.push_back(rewriter->create( - loc, index_value, rewriter->getIntegerType(32))); - } - - return shape_values; -} - -// Helper function for OpRewritePattern classes to materialize dynamic -// broadcasts on LHS and RHS arguments to a binary op. -// -// Returns true and set out_lhs and out_rhs for materialized dynamic broadcasts -// for LHS and RHS arguments, else returns false. -template -bool CreateDynamicBroadcastsForBinaryOp(SrcOp op, PatternRewriter *rewriter, - Value *out_lhs, Value *out_rhs) { - if (!op.broadcast_dimensions().hasValue()) { - // Note: the op may still have an implicit broadcast on it, such as - // for (tensor<1xf32>, tensor<4xf32>). - return false; - } - - // Insert BroadcastInDimOps for the left-hand-side and right-hand-side args, - // replacing the original LHS and RHS args in the source op with the results - // of the broadcasts. - Value lhs = op.lhs(); - Value rhs = op.rhs(); - - auto lhs_ranked_type = lhs.getType().dyn_cast(); - auto rhs_ranked_type = rhs.getType().dyn_cast(); - if (!lhs_ranked_type || !rhs_ranked_type) { - // Unranked, can't determine at this point how to perform the broadcast. - return false; - } - - auto lhs_rank = lhs_ranked_type.getRank(); - auto rhs_rank = rhs_ranked_type.getRank(); - - // Set broadcast_dimensions to [0, ..., rank] for the higher rank arg. - // Use the original op.broadcast_dimensions for the lower rank arg. - auto higher_rank_broadcast_dims = - GetI64ElementsAttrForSeq(0, std::max(lhs_rank, rhs_rank), rewriter); - DenseIntElementsAttr lhs_broadcast_dims; - DenseIntElementsAttr rhs_broadcast_dims; - std::vector shape_elements; - if (lhs_rank > rhs_rank) { - lhs_broadcast_dims = higher_rank_broadcast_dims; - rhs_broadcast_dims = op.broadcast_dimensions().getValue(); - shape_elements = ComputeBroadcastedShape(op, rhs, lhs, rewriter); - } else if (lhs_rank < rhs_rank) { - lhs_broadcast_dims = op.broadcast_dimensions().getValue(); - rhs_broadcast_dims = higher_rank_broadcast_dims; - shape_elements = ComputeBroadcastedShape(op, lhs, rhs, rewriter); - } else { - // This shouldn't happen for legal ops. If the broadcast_dimensions - // attribute is set, the ranks should be different. - // TODO(scotttodd): Add a custom verification for ops and assert here. - return false; - } - - // DynamicBroadcastInDimOp preserves the element type but produces a tensor - // with unranked shape. The rank of the output is the length of the - // output shape argument. - SmallVector op_shape(shape_elements.size(), - RankedTensorType::kDynamicSize); - auto lhs_type = - RankedTensorType::get(op_shape, lhs_ranked_type.getElementType()); - auto rhs_type = - RankedTensorType::get(op_shape, rhs_ranked_type.getElementType()); - - // We need a way to turn a list of scalars into a vector. While Standard - // dialect does not have one, use the XLA_HLO variant. 
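The deleted helpers above compute result shapes with numpy-style semantics: an output dimension keeps the larger operand's extent unless that extent is 1 and a dimension of the smaller operand is mapped onto it via broadcast_dimensions, in which case the smaller operand's extent is used. A compact C++ sketch of that rule for static extents (names and the shape representation are invented here for illustration):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // `large` has the output rank; `small` has lower rank; broadcast_dims[i] is
    // the output dimension that small's i-th dimension maps to.
    std::vector<int64_t> BroadcastedShape(const std::vector<int64_t>& large,
                                          const std::vector<int64_t>& small,
                                          const std::vector<int64_t>& broadcast_dims) {
      std::vector<int64_t> result(large);
      for (std::size_t i = 0; i < small.size(); ++i) {
        auto out_dim = static_cast<std::size_t>(broadcast_dims[i]);
        // If the larger operand has extent 1 here, the smaller operand's extent
        // wins; otherwise the extents are expected to agree (not re-checked,
        // matching the original helper's comment).
        if (large[out_dim] == 1) result[out_dim] = small[i];
      }
      return result;
    }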
- int shape_size = shape_elements.size(); - Type shape_element_type = shape_elements.front().getType(); - Value shape_value = rewriter->create( - op.getLoc(), RankedTensorType::get({shape_size}, shape_element_type), - shape_elements); - - *out_lhs = rewriter->createOrFold( - op.getLoc(), lhs_type, lhs, shape_value, lhs_broadcast_dims); - *out_rhs = rewriter->createOrFold( - op.getLoc(), rhs_type, rhs, shape_value, rhs_broadcast_dims); - return true; -} - -template -bool CreateBroadcastForBinaryOp(SrcOp op, PatternRewriter *rewriter, - Value *out_lhs, Value *out_rhs) { - auto op_ranked_type = op.getType().template dyn_cast(); - if (!op_ranked_type) return false; - - if (op_ranked_type.hasStaticShape()) { - if (!CreateStaticBroadcastsForBinaryOp(op, rewriter, out_lhs, out_rhs)) { - return false; - } - } else { - if (!CreateDynamicBroadcastsForBinaryOp(op, rewriter, out_lhs, out_rhs)) { - return false; - } - } - return true; -} - -template -struct BinaryOpWithBroadcastConvert : public OpRewritePattern { - explicit BinaryOpWithBroadcastConvert(MLIRContext *context) - : OpRewritePattern(context) {} - - LogicalResult matchAndRewrite(SrcOp op, - PatternRewriter &rewriter) const override { - Value new_lhs; - Value new_rhs; - - if (!CreateBroadcastForBinaryOp(op, &rewriter, &new_lhs, &new_rhs)) - return failure(); - - // Replace the original op with a new one that uses the new args. - // New args are broadcasts, so no dims are needed on the replacement op. - rewriter.replaceOpWithNewOp(op, op.getType(), new_lhs, new_rhs, - /*broadcast_dims=*/nullptr); - return success(); - } -}; - // Converts ClampOp with broadcast semantics. ClampOp requires "all three arrays // must be the same shape. Alternatively, as a restricted form of broadcasting, // min and/or max can be a scalar of type T." @@ -335,57 +69,10 @@ struct ClampWithBroadcastConvert : public OpRewritePattern { } }; -// Specialized class for CompareOp, as it has an additional builder argument. -struct CompareWithBroadcastConvert : public OpRewritePattern { - explicit CompareWithBroadcastConvert(MLIRContext *context) - : OpRewritePattern(context) {} - - LogicalResult matchAndRewrite(CompareOp op, - PatternRewriter &rewriter) const override { - Value new_lhs; - Value new_rhs; - - if (!CreateBroadcastForBinaryOp(op, &rewriter, &new_lhs, &new_rhs)) - return failure(); - - rewriter.replaceOpWithNewOp(op, op.getType(), new_lhs, new_rhs, - /*broadcast_dims=*/nullptr, - op.comparison_direction()); - return success(); - } -}; - } // namespace void SetupMaterializeBroadcastsLegality(MLIRContext *context, ConversionTarget *conversionTarget) { -#define ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(OpType) \ - conversionTarget->addDynamicallyLegalOp( \ - [](OpType op) { return !op.broadcast_dimensions().hasValue(); }); - // Binary elementwise ops. - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(AddOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(Atan2Op); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(DivOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(MaxOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(MinOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(MulOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(PowOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(RemOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(ShiftLeftOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(ShiftRightArithmeticOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(ShiftRightLogicalOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(SubOp); - - // Binary logical elementwise ops. 
- ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(AndOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(OrOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(XorOp); - - // CompareOp. - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(CompareOp); - -#undef ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST - conversionTarget->addDynamicallyLegalOp([](ClampOp op) { return op.max().getType() == op.operand().getType() && op.min().getType() == op.operand().getType(); @@ -394,30 +81,10 @@ void SetupMaterializeBroadcastsLegality(MLIRContext *context, void PopulateMaterializeBroadcastsPatterns(MLIRContext *context, OwningRewritePatternList *patterns) { - // Binary elementwise ops. - patterns->insert>(context); - patterns->insert>(context); - patterns->insert>(context); - patterns->insert>(context); - patterns->insert>(context); - patterns->insert>(context); - patterns->insert>(context); - patterns->insert>(context); - patterns->insert>(context); - patterns->insert>( - context); - patterns->insert>(context); - patterns->insert>(context); - - // Binary logical elementwise ops. - patterns->insert>(context); - patterns->insert>(context); - patterns->insert>(context); - - // ClampOp. It can have a restricted form of broadcasting. + // ClampOp. This op has a special case where it accepts either same-shaped + // inputs or scalars (a restricted form of broadcasting). This makes the + // broadcast explicit. patterns->insert(context); - // CompareOp. Note the specialized class instead of using the template. - patterns->insert(context); } } // namespace xla_hlo diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h index 2d0164981a3..e3dd5380d7c 100644 --- a/tensorflow/compiler/mlir/xla/transforms/passes.h +++ b/tensorflow/compiler/mlir/xla/transforms/passes.h @@ -36,7 +36,7 @@ namespace xla_hlo { /// Lowers from TF dialect to HLO dialect. When allow_partial_conversion is /// false, emits an error if there is any operation that can't be legalized. std::unique_ptr> createLegalizeTFPass( - bool allow_partial_conversion = false); + bool allow_partial_conversion = false, bool legalize_chlo = true); /// Lowers from TF dialect to HLO dialect using tf2xla op kernels for the /// specified device type. @@ -50,7 +50,8 @@ std::unique_ptr> createLegalizeTFControlFlowPass(); /// dialect using the conversion patterns registered by the HLO dialect. When /// allow_partial_conversion is false, emits an error if there is any operation /// that can't be legalized. -LogicalResult legalizeTF(Operation* op, bool allow_partial_conversion = false); +LogicalResult legalizeTF(Operation* op, bool allow_partial_conversion = false, + bool legalize_chlo = true); /// Lowers HLO control flow ops to the Standard dialect. std::unique_ptr> createLegalizeControlFlowPass(); @@ -65,6 +66,13 @@ std::unique_ptr> createLegalizeToLhloPass(); // Lowers from HLO dialect to Linalg dialect. std::unique_ptr> createLegalizeHloToLinalgPass(); +// Transforms unranked HLO operations to ranked ones where possible. +std::unique_ptr> createTransformUnrankedHloPass(); + +// Sinks constants implicitly captured in control flow regions. This is +// necessary to export to XLA. +std::unique_ptr> createSinkConstantsToControlFlowPass(); + } // namespace xla_hlo namespace xla_lhlo { @@ -81,8 +89,8 @@ std::unique_ptr> createLegalizeToGpuPass(); // Fuses linalg ops obtained after LHLO lowering. To enable fusion, // operations are first tiled. // -// When 'use_parallel_loops' is set, the tiling will use loop.parallel -// operations. 
Otherwise, loop.for operations are used. +// When 'use_parallel_loops' is set, the tiling will use scf.parallel +// operations. Otherwise, scf.for operations are used. // // 'tile_sizes' provides the tile sizes to use for tiling. If the linalg // operation has more dimensions than tile sizes provided, 1 is used as diff --git a/tensorflow/compiler/mlir/xla/transforms/rewriters.h b/tensorflow/compiler/mlir/xla/transforms/rewriters.h index ad81cda19b9..59347198fe4 100644 --- a/tensorflow/compiler/mlir/xla/transforms/rewriters.h +++ b/tensorflow/compiler/mlir/xla/transforms/rewriters.h @@ -23,6 +23,9 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" // from @llvm-project namespace mlir { +class LLVMTypeConverter; +class OwningRewritePatternList; +class BufferAssignmentPlacer; namespace xla_hlo { // Collection of rewrite patterns for lowering a general dot product. @@ -38,9 +41,9 @@ void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns, MLIRContext *ctx); // Collection of rewrite patterns for lowering of HLO to LHLO dialect. -void populateHLOToLHLOConversionPattern(MLIRContext *context, - OwningRewritePatternList *patterns); - +void populateHLOToLHLOConversionPattern( + MLIRContext *context, BufferAssignmentPlacer *bufferAssignment, + TypeConverter *converter, OwningRewritePatternList *patterns); // Collection of rewrite patterns for lowering of HLO to Linalg dialect. void populateHLOToLinalgConversionPattern(MLIRContext *context, OwningRewritePatternList *patterns); @@ -54,6 +57,15 @@ void SetupMaterializeBroadcastsLegality(MLIRContext *context, void PopulateMaterializeBroadcastsPatterns(MLIRContext *context, OwningRewritePatternList *patterns); +// Sets up legality definitions for element-wise operations on ranked tensors. +void SetupTransformUnrankedHloLegality(MLIRContext *context, + ConversionTarget *conversionTarget); + +// Populates a collection of rewrite patterns to realize element-wise operations +// on ranked tensors where possible. +void PopulateTransformUnrankedHloPatterns(MLIRContext *context, + OwningRewritePatternList *patterns); + // Populate a collection of conversion patterns for un-fusing // batch_norm_inference and batch_norm_training into constituent HLO ops. // TODO(laurenzo): Implement un-fusing of batch_norm_training. @@ -62,6 +74,14 @@ void PopulateUnfuseBatchNormPatterns(MLIRContext *context, } // namespace xla_hlo +namespace xla_lhlo { + +/// Collect a set of patterns to convert from the LHLO dialect to LLVM. +void PopulateLhloToLLVMConversionPatterns(LLVMTypeConverter *converter, + OwningRewritePatternList *patterns); + +} // namespace xla_lhlo + namespace xla_chlo { // Populates a collection of conversion patterns for legalizing client-HLO to diff --git a/tensorflow/compiler/mlir/xla/transforms/sink_constants_to_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/sink_constants_to_control_flow.cc new file mode 100644 index 00000000000..5a45e0f3b18 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/sink_constants_to_control_flow.cc @@ -0,0 +1,85 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" + +namespace mlir { +namespace xla_hlo { + +namespace { + +// A pass that sinks constants implicitly captured in control flow regions. This +// is necessary to export to XLA. +class SinkConstantsToControlFlow + : public mlir::PassWrapper { + void runOnFunction() override { + getFunction().walk([](Operation* op) { + if (auto while_op = llvm::dyn_cast(op)) { + SinkToRegion(&while_op.body()); + SinkToRegion(&while_op.cond()); + } else if (auto if_op = llvm::dyn_cast(op)) { + SinkToRegion(&if_op.true_branch()); + SinkToRegion(&if_op.false_branch()); + } + }); + } + + private: + // Performs constant sinking into a region. + static void SinkToRegion(Region* region) { + llvm::DenseMap sunk_constant; + visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) { + Value constant = use->get(); + auto const_op = dyn_cast_or_null(constant.getDefiningOp()); + if (!const_op) return; + auto map_entry = sunk_constant.try_emplace(constant, nullptr); + if (!map_entry.second) { + // This constant has already been cloned into the region, reuse it. + use->set(map_entry.first->getSecond().getResult()); + if (constant.use_empty()) const_op.erase(); + return; + } + if (constant.hasOneUse()) { + const_op.getOperation()->moveBefore(®ion->front().front()); + return; + } + map_entry.first->getSecond() = const_op.clone(); + region->front().getOperations().insert(region->front().begin(), + map_entry.first->getSecond()); + use->set(map_entry.first->getSecond().getResult()); + }); + } +}; + +static mlir::PassRegistration pass( + "xla-hlo-sink-constants-to-control-flow", + "Sink constants implicitly captured in control flow regions. 
This is " + "necessary to export to XLA."); + +} // anonymous namespace + +std::unique_ptr> createSinkConstantsToControlFlowPass() { + return std::make_unique(); +} + +} // namespace xla_hlo +} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/test_infer_shaped_type_pass.cc b/tensorflow/compiler/mlir/xla/transforms/test_infer_shaped_type_pass.cc index 8976bd5b7d2..71441656c08 100644 --- a/tensorflow/compiler/mlir/xla/transforms/test_infer_shaped_type_pass.cc +++ b/tensorflow/compiler/mlir/xla/transforms/test_infer_shaped_type_pass.cc @@ -38,7 +38,8 @@ struct InferReturnTypeComponentsPattern : public RewritePattern { SmallVector components; if (failed(defining_op_int.inferReturnTypeComponents( op->getContext(), op->getLoc(), defining_op->getOperands(), - defining_op->getAttrs(), defining_op->getRegions(), components))) { + defining_op->getAttrDictionary(), defining_op->getRegions(), + components))) { return failure(); } diff --git a/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm.cc b/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm.cc index 32d8b079c89..eead03404cb 100644 --- a/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm.cc +++ b/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm.cc @@ -58,13 +58,9 @@ Value CalculateShapeValue(Location loc, Value operand, int64_t rank = result_type.getRank(); shape_values.reserve(rank); for (int64_t i = 0; i < rank; ++i) { - auto index_value = rewriter.create(loc, operand, i); - shape_values.push_back(rewriter.create( - loc, index_value, rewriter.getIntegerType(32))); + shape_values.push_back(rewriter.create(loc, operand, i)); } - Type shape_element_type = shape_values.front().getType(); - return rewriter.create( - loc, RankedTensorType::get({rank}, shape_element_type), shape_values); + return rewriter.create(loc, shape_values); } Value MaterializeEpsilon(Operation* op, FloatAttr epsilon_attr, @@ -137,8 +133,8 @@ class UnfuseBatchNormInferencePattern if (!epsilon) { return failure(); } - Value stddev = rewriter.create( - bn_op.getLoc(), bn_op.variance(), epsilon, /*broadcast_dims=*/nullptr); + Value stddev = rewriter.create(bn_op.getLoc(), + bn_op.variance(), epsilon); stddev = rewriter.create(bn_op.getLoc(), stddev); // Broadcast all terms. @@ -162,13 +158,13 @@ class UnfuseBatchNormInferencePattern // Compute: // scale * (input - mean) / stddev + offset Value result = rewriter.create( - bn_op.getLoc(), bn_op.operand(), broadcast_mean, nullptr); + bn_op.getLoc(), bn_op.operand(), broadcast_mean); result = rewriter.create(bn_op.getLoc(), result, - broadcast_scale, nullptr); + broadcast_scale); result = rewriter.create(bn_op.getLoc(), result, - broadcast_stddev, nullptr); - rewriter.replaceOpWithNewOp(bn_op, result, broadcast_offset, - nullptr); + broadcast_stddev); + rewriter.replaceOpWithNewOp(bn_op, result, + broadcast_offset); return success(); } diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.cc index 436a3e701e1..a12bd9e7c1a 100644 --- a/tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.cc @@ -251,17 +251,15 @@ Value LhloDialectEmitter::GetOrCreateView( // Create the view for this slice size, possible with an affine map to model // the offset. The result is cached in the slices_ map. 
- SmallVector offset_map; - if (slice.offset()) { - offset_map.push_back(AffineMap::get( - /*dimCount=*/1, /*symbolCount=*/0, - {getAffineDimExpr(0, builder_.getContext()) + slice.offset()}, - builder_.getContext())); - } - auto slice_type = MemRefType::get({slice.size()}, i8_type_, offset_map); + // The std.view result type does not carry the static offset: this is not + // useful information. Rather, the view op must have the static offset. + auto slice_type = MemRefType::get({slice.size()}, i8_type_, {}); - auto slice_view = builder_.create( - alloc_buffer.getLoc(), slice_type, alloc_buffer, /*operands=*/llvm::None); + Value byte_shift = + builder_.create(alloc_buffer.getLoc(), slice.offset()); + auto slice_view = + builder_.create(alloc_buffer.getLoc(), slice_type, alloc_buffer, + byte_shift, /*sizes=*/ArrayRef{}); slices_.insert({slice_key, slice_view}); return slice_view; } @@ -277,9 +275,12 @@ StatusOr LhloDialectEmitter::GetOrCreateView( Value slice_view = GetOrCreateView(out_slice); TF_ASSIGN_OR_RETURN(Type out_type, ::xla::ConvertShapeToType( target_shape, builder_)); + Value byte_shift = + builder_.create(builder_.getUnknownLoc(), 0); if (slice_view.getType() != out_type) - slice_view = builder_.create(builder_.getUnknownLoc(), out_type, - slice_view, llvm::None); + slice_view = + builder_.create(builder_.getUnknownLoc(), out_type, slice_view, + byte_shift, /*sizes=*/ArrayRef{}); return slice_view; } diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc index 9cce6799288..2b496677d62 100644 --- a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc @@ -84,7 +84,8 @@ class PointwiseToLinalgConverter : public OpConversionPattern { emitError(loc, "lhlo to linalg conversion expects ranked args"); return failure(); } - if (!argType.getElementType().isSignlessIntOrFloat()) { + auto elemTy = argType.getElementType(); + if (!elemTy.isSignlessIntOrFloat() && !elemTy.template isa()) { return failure(); } @@ -572,6 +573,34 @@ class ConstConverter : public OpConversionPattern { } }; +// TODO(b/156787842): Support the lowering for dynamic shapes. +template +class ReverseConverter + : public DataMovementOpConverter, OpTy, + isLHLO> { + public: + using DataMovementOpConverter, OpTy, + isLHLO>::DataMovementOpConverter; + static ArrayAttr getIndexingMapsAttr(OpTy op, Builder* b) { + auto resultType = + getXLAOpResultType(op).template cast(); + auto nloops = resultType.getRank(); + SmallVector inputExprs; + inputExprs.reserve(nloops); + for (int i = 0; i < nloops; ++i) + inputExprs.push_back(b->getAffineDimExpr(i)); + for (auto dim : op.dimensions()) { + int i = dim.getZExtValue(); + if (resultType.isDynamicDim(i)) return {}; + int n = resultType.getShape()[i]; + inputExprs[i] = b->getAffineConstantExpr(n - 1) - inputExprs[i]; + } + return b->getAffineMapArrayAttr( + {AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()), + b->getMultiDimIdentityMap(nloops)}); + } +}; + class SliceConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -618,17 +647,20 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, // TODO(ataei): Remove this pattern, CopyOp is folded away. 
PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, @@ -638,6 +670,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, ReshapeAddRemoveDimConverter, + ReverseConverter, ScalarPointwiseToStandardConverter, SliceConverter >(context); @@ -716,16 +749,19 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, @@ -735,6 +771,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, ReshapeAddRemoveDimConverter, ReshapeOpConverter, + ReverseConverter, TransposeConverter>(context); } diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_transform_unranked_hlo.cc b/tensorflow/compiler/mlir/xla/transforms/xla_transform_unranked_hlo.cc new file mode 100644 index 00000000000..b2afc7c1026 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/xla_transform_unranked_hlo.cc @@ -0,0 +1,122 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+ +==============================================================================*/ + +#include "absl/memory/memory.h" +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" +#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" + +namespace mlir { +namespace xla_hlo { +namespace { + +template +inline void AddLegalOpOnRankedTensor(ConversionTarget *conversionTarget) { + conversionTarget->addDynamicallyLegalOp([](OpTy op) { + return op.getOperand().getType().template cast().hasRank(); + }); +} + +template +struct UnaryElementwiseOpConversion : public OpRewritePattern { + explicit UnaryElementwiseOpConversion(MLIRContext *context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + // Don't apply conversion to ops with statically shaped operands. + Value operand = op.getOperand(); + auto operandTy = operand.getType().dyn_cast(); + if (operandTy.hasRank()) return failure(); + + // Generate IR to flatten the operand. + auto loc = op.getLoc(); + Value shape = rewriter.create(loc, operand); + Value numElements = rewriter.create( + loc, rewriter.getType(), shape); + Value numElementsAsIndex = rewriter.create( + loc, rewriter.getIndexType(), numElements); + Value flatShapeAsDimTensor = + rewriter.create(loc, numElementsAsIndex); + auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize}, + operandTy.getElementType()); + Value flatOperand = rewriter.create( + loc, flatTensorTy, operand, flatShapeAsDimTensor); + + // Generate IR for the actual operation. + Value flatResult = rewriter.create(loc, flatTensorTy, flatOperand); + + // Generate IR to restore the original shape. + auto extentTensorTy = RankedTensorType::get({ShapedType::kDynamicSize}, + rewriter.getIndexType()); + Value shapeAsExtentTensor = + rewriter.create(loc, extentTensorTy, shape); + Value result = rewriter.create( + loc, operandTy, flatResult, shapeAsExtentTensor); + rewriter.replaceOp(op, result); + + return success(); + } +}; + +struct TransformUnrankedHloPass + : public PassWrapper { + void runOnFunction() override { + ConversionTarget conversionTarget(getContext()); + OwningRewritePatternList conversionPatterns; + SetupTransformUnrankedHloLegality(&getContext(), &conversionTarget); + PopulateTransformUnrankedHloPatterns(&getContext(), &conversionPatterns); + + if (failed(applyPartialConversion(getFunction(), conversionTarget, + conversionPatterns))) + return signalPassFailure(); + } +}; + +} // namespace + +void SetupTransformUnrankedHloLegality(MLIRContext *context, + ConversionTarget *conversionTarget) { + conversionTarget->addLegalDialect(); + + // Targeted operations are only legal when they operate on ranked tensors. 
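The conversion above rank-specializes a unary elementwise op on an unranked tensor: it materializes the operand's shape, flattens the operand to a 1-D tensor whose extent is the total element count, applies the op, and reshapes the result back (the legality setup continues right after this sketch). A plain C++ sketch of that flatten/apply/restore idea, using sqrt as an assumed example of a unary elementwise op, purely illustrative:

    #include <cmath>
    #include <cstddef>
    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <vector>

    // Whatever the (unknown) rank, the buffer can be treated as a 1-D tensor of
    // num_elements(shape) values; the op is applied to that flat view, and the
    // original shape still describes the result.
    std::vector<float> ElementwiseSqrt(const std::vector<float>& flat_data,
                                       const std::vector<int64_t>& shape) {
      int64_t num_elements = std::accumulate(
          shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
      std::vector<float> result(static_cast<std::size_t>(num_elements));
      for (int64_t i = 0; i < num_elements; ++i) {
        auto idx = static_cast<std::size_t>(i);
        result[idx] = std::sqrt(flat_data[idx]);
      }
      return result;  // interpret together with the original `shape`
    }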
+ AddLegalOpOnRankedTensor(conversionTarget); +} + +void PopulateTransformUnrankedHloPatterns(MLIRContext *context, + OwningRewritePatternList *patterns) { + patterns->insert>(context); +} + +std::unique_ptr> createTransformUnrankedHloPass() { + return absl::make_unique(); +} + +static PassRegistration transform_unranked_hlo_pass( + "transform-unranked-hlo", + "Realize element-wise operations on ranked tensors where possible"); + +} // namespace xla_hlo +} // namespace mlir diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index cd22b527444..ea4ba8dab6b 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -470,6 +470,7 @@ tf_xla_py_test( name = "concat_ops_test", size = "medium", srcs = ["concat_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "many_xla_args", @@ -561,6 +562,7 @@ tf_xla_py_test( name = "dynamic_slice_ops_test", size = "small", srcs = ["dynamic_slice_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1342,6 +1344,7 @@ tf_xla_py_test( name = "ternary_ops_test", size = "medium", srcs = ["ternary_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1384,6 +1387,7 @@ tf_xla_py_test( size = "medium", srcs = ["fused_batchnorm_test.py"], python_version = "PY3", + shard_count = 5, tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip ], diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 92ea1cfaf87..eb8883c9ccd 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import itertools -import os import numpy as np @@ -73,8 +72,6 @@ class BinaryOpsTest(xla_test.XLATestCase): self.assertAllCloseAccordingToType( result[i], expected[i], rtol=rtol, atol=atol) - @test_util.disable_mlir_bridge( - "F16 type is not supported in CreateDenseElementsAttrFromLiteral") def testFloatOps(self): for dtype in self.float_types: if dtype == dtypes.bfloat16.as_numpy_dtype: @@ -1020,7 +1017,8 @@ class BinaryOpsTest(xla_test.XLATestCase): math_ops.matmul, np.array([[3.1415926535897932]], dtype=dtype), np.array([[2.7182818284590452]], dtype=dtype), - expected=np.array([[8.5397342226735668]], dtype=dtype)) + expected=np.array([[8.5397342226735668]], dtype=dtype), + rtol=1e-14) # Edge case with a large range of exponent. Not supported by float16. if dtype != np.float16: @@ -1028,7 +1026,8 @@ class BinaryOpsTest(xla_test.XLATestCase): math_ops.matmul, np.array([[9.4039548065783000e-38]], dtype=dtype), np.array([[4.5070591730234615e37]], dtype=dtype), - expected=np.array([[4.2384180773686798]], dtype=dtype)) + expected=np.array([[4.2384180773686798]], dtype=dtype), + rtol=1e-14) # TODO(phawkins): failing on GPU, no registered kernel. 
def DISABLED_testSparseMatMul(self): @@ -1098,8 +1097,6 @@ class BinaryOpsTest(xla_test.XLATestCase): x, expected=np.matmul(x, x.transpose([0, 1, 3, 2]))) - @test_util.disable_mlir_bridge( - "TODO(b/155097273): Handle complex dtype constants") def testExpandDims(self): for dtype in self.numeric_types: self._testBinary( @@ -1197,8 +1194,6 @@ class BinaryOpsTest(xla_test.XLATestCase): np.full([1, 1, 3, 5], 3., dtype=np.float32), expected=np.full([4, 5, 1, 2, 5], 18., dtype=np.float32)) - @test_util.disable_mlir_bridge( - "TODO(b/155097273): Handle complex dtype constants") def testPad(self): for dtype, pad_type in itertools.product( self.numeric_types, [np.int32, np.int64]): @@ -1339,8 +1334,6 @@ class BinaryOpsTest(xla_test.XLATestCase): ], dtype=dtype)) - @test_util.disable_mlir_bridge( - "TODO(b/155097273): Handle complex dtype constants") def testReshape(self): for dtype in self.numeric_types: self._testBinary( @@ -1473,8 +1466,6 @@ class BinaryOpsTest(xla_test.XLATestCase): [1, 2]], dtype=dtype)) - @test_util.disable_mlir_bridge( - "TODO(b/155097273): Handle complex dtype constants") def testTranspose(self): for dtype in self.numeric_types: self._testBinary( @@ -1493,8 +1484,6 @@ class BinaryOpsTest(xla_test.XLATestCase): np.array([1, 0], dtype=np.int32), expected=np.array([[1, 3], [2, 4]], dtype=dtype)) - @test_util.disable_mlir_bridge( - "TODO(b/155097273): Handle complex dtype constants") def testConjugateTranspose(self): for dtype in self.complex_types: self._testBinary( @@ -1513,7 +1502,6 @@ class BinaryOpsTest(xla_test.XLATestCase): np.array([1, 0], dtype=np.int32), expected=np.array([[1 + 1j, 3 + 3j], [2 - 2j, 4 - 4j]], dtype=dtype)) - @test_util.disable_mlir_bridge("Enable tf.Cross Compilation") def testCross(self): for dtype in self.float_types: self._testBinary( @@ -1592,8 +1580,6 @@ class BinaryOpsTest(xla_test.XLATestCase): np.array([4, 5, 6], dtype=np.int32), expected=None) - @test_util.disable_mlir_bridge( - "Requires BroadcastInDim method in MlirHloBuilder") def testBroadcastTo(self): for dtype in self.all_types: x = np.random.randint(0, high=100, size=[2, 3]) @@ -1604,29 +1590,16 @@ class BinaryOpsTest(xla_test.XLATestCase): expected=x) self._testBinary( array_ops.broadcast_to, - x, - np.array([6, 6], dtype=np.int32), - expected=np.tile(x, [3, 2])) + np.zeros([2, 3], dtype=dtype), + np.array([2, 2, 3], dtype=np.int32), + expected=np.zeros([2, 2, 3], dtype=dtype)) + + x = np.arange(2).reshape((2, 1)).astype(dtype) self._testBinary( array_ops.broadcast_to, x, - np.array([7, 4, 3], dtype=np.int32), - expected=np.tile(x, [7, 2, 1])) - self._testBinary( - array_ops.broadcast_to, - x, - np.array([7, 0, 3], dtype=np.int32), - expected=np.zeros([7, 0, 3], dtype=dtype)) - self._testBinary( - array_ops.broadcast_to, - x, - np.array([7, 1, 2, 9], dtype=np.int32), - expected=np.tile(x, [7, 1, 1, 3])) - self._testBinary( - array_ops.broadcast_to, - np.zeros([2, 0], dtype=dtype), - np.array([4, 0], dtype=np.int32), - expected=np.zeros([4, 0], dtype=dtype)) + np.array([2, 2, 3], dtype=np.int32), + expected=np.tile(x, (2, 1, 3))) x = np.arange(3).reshape((3, 1, 1, 1)).astype(dtype) self._testBinary( @@ -1637,8 +1610,4 @@ class BinaryOpsTest(xla_test.XLATestCase): if __name__ == "__main__": - # TODO(b/130689556): XLA CPU does not honor inf/nan which causes problems - os.environ[ - "XLA_FLAGS"] = "--xla_cpu_enable_fast_math=false " + os.environ.get( - "XLA_FLAGS", "") googletest.main() diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py 
index 10dd2d6542c..f35ded924d5 100644 --- a/tensorflow/compiler/tests/concat_ops_test.py +++ b/tensorflow/compiler/tests/concat_ops_test.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gradients_impl @@ -293,6 +294,7 @@ class ConcatTest(xla_test.XLATestCase): # The purpose of this is to ensure that XLA on GPU will not run out of memory # with too many arguments. + @test_util.disable_mlir_bridge("TODO(b/153895138): Debug.") def testConcatLargeNumberOfTensors(self): if "CPU" in self.device: self.skipTest("This test can time out on CPU, so we will just allow " diff --git a/tensorflow/compiler/tests/data_format_ops_test.py b/tensorflow/compiler/tests/data_format_ops_test.py index 681c1f3499e..08d44256b50 100644 --- a/tensorflow/compiler/tests/data_format_ops_test.py +++ b/tensorflow/compiler/tests/data_format_ops_test.py @@ -81,11 +81,21 @@ class XlaPermuteOpTest(xla_test.XLATestCase): x = np.array([7, 4, 9, 3], dtype=dtype) self._runPermuteAndCompare(x, "NHWC", "NCHW", [7, 3, 4, 9]) + def testNHWCToNCHW_Size2(self): + for dtype in {np.int32, np.int64}: + x = np.array([4, 9], dtype=dtype) + self._runPermuteAndCompare(x, "NHWC", "NCHW", [4, 9]) + def testNCHWToNHWC(self): for dtype in {np.int32, np.int64}: x = np.array([7, 4, 9, 3], dtype=dtype) self._runPermuteAndCompare(x, "NCHW", "NHWC", [7, 9, 3, 4]) + def testNCHWToNHWC_Size2(self): + for dtype in {np.int32, np.int64}: + x = np.array([9, 3], dtype=dtype) + self._runPermuteAndCompare(x, "NCHW", "NHWC", [9, 3]) + def testNHWCToHWNC(self): for dtype in {np.int32, np.int64}: x = np.array([7, 4, 9, 3], dtype=dtype) diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py index 6a9076e9be8..a36effe5984 100644 --- a/tensorflow/compiler/tests/fused_batchnorm_test.py +++ b/tensorflow/compiler/tests/fused_batchnorm_test.py @@ -23,7 +23,6 @@ import numpy as np from tensorflow.compiler.tests import test_utils from tensorflow.compiler.tests import xla_test -from tensorflow.python.compat import compat from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gradient_checker @@ -132,9 +131,6 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): def _testLearning(self, use_gradient_checker, data_format, exponential_avg_factor): - if not compat.forward_compatible(2020, 3, - 6) and exponential_avg_factor != 1.0: - self.skipTest("running average not available.") channel = 3 x_shape = [2, 2, 6, channel] scale_shape = [channel] diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py index a2a47f19a6e..7bbfecff403 100644 --- a/tensorflow/compiler/tests/ternary_ops_test.py +++ b/tensorflow/compiler/tests/ternary_ops_test.py @@ -24,6 +24,7 @@ import scipy.special as sps from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops @@ -47,6 +48,8 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase): {'start': 1, 'end': 2, 'num': 1}, 
{'start': 1, 'end': 4, 'num': 3}, {'start': 0, 'end': 41, 'num': 42}) + @test_util.disable_mlir_bridge( + 'TODO(b/156174708): Dynamic result types not supported') def testLinspace(self, start, end, num): expected = np.linspace(start, end, num, dtype=np.float32) result = self._testTernary( @@ -211,6 +214,7 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase): upper, expected=np.minimum(np.maximum(x, lower), upper)) + @test_util.disable_mlir_bridge('Enable tf.Betainc Compilation') def testBetaincSanity(self): # This operation is only supported for float32 and float64. for dtype in self.numeric_types & {np.float32, np.float64}: @@ -248,6 +252,7 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase): 'atol': 2e-4 }, ) + @test_util.disable_mlir_bridge('Enable tf.Betainc Compilation') def testBetainc(self, sigma, rtol, atol): # This operation is only supported for float32 and float64. for dtype in self.numeric_types & {np.float32, np.float64}: diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 0c4c7bacdf3..85bf89c4f9e 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -186,8 +186,6 @@ class UnaryOpsTest(xla_test.XLATestCase): self._assertOpOutputMatchesExpected( math_ops.cos, x, expected=np.cos(x), rtol=tol, atol=1e-5) - @test_util.disable_mlir_bridge( - "TODO(b/153812660): Handle tf.Softmax compilation") def testFloatOps(self): for dtype in self.float_types: x = np.arange(-0.90, 0.90, 0.25) @@ -349,17 +347,15 @@ class UnaryOpsTest(xla_test.XLATestCase): expected=np.array( [1.55740772, -2.18503986, -0.14254654, 1.15782128], dtype=dtype)) - # TODO(b/130689556): Turn this on for CPU when we start honoring NaNs. 
- if self.device != "XLA_CPU": - self._assertOpOutputMatchesExpected( - math_ops.tanh, - np.array([[1, 2, 3, 4], [np.inf, -np.inf, np.nan, 20], - [19, -19, 22, -22]], - dtype=dtype), - expected=np.array( - [[0.76159418, 0.96402758, 0.99505478, 0.99932933], - [1.0, -1.0, np.nan, 1.0], [1.0, -1.0, 1.0, -1.0]], - dtype=dtype)) + self._assertOpOutputMatchesExpected( + math_ops.tanh, + np.array([[1, 2, 3, 4], [np.inf, -np.inf, np.nan, 20], + [19, -19, 22, -22]], + dtype=dtype), + expected=np.array( + [[0.76159418, 0.96402758, 0.99505478, 0.99932933], + [1.0, -1.0, np.nan, 1.0], [1.0, -1.0, 1.0, -1.0]], + dtype=dtype)) self._assertOpOutputMatchesExpected( nn_ops.log_softmax, @@ -514,6 +510,11 @@ class UnaryOpsTest(xla_test.XLATestCase): ], dtype=dtype)) + @test_util.disable_mlir_bridge( + "TODO(b/153812660): Handle tf.QuantizeAndDequantize compilation") + def testQuantizeAndDequantize(self): + for dtype in self.float_types: + def quantize_and_dequantize_v2(x): return array_ops.quantize_and_dequantize_v2( x, -127, 127, signed_input=True, num_bits=8) @@ -598,8 +599,6 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([-1, -0.5, 0, 0.3], dtype=dtype), expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype)) - @test_util.disable_mlir_bridge( - "Complex types not supported in CreateDenseElementsAttrFromLiteral") def testComplexOps(self): for dtype in self.complex_types: diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py index 3b304df9024..f3e915daa67 100644 --- a/tensorflow/compiler/tests/xla_ops_test.py +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -51,7 +51,6 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): equality_fn = self.assertAllClose equality_fn(result, expected, rtol=1e-3) - @test_util.disable_mlir_bridge('Not supported yet') def testAdd(self): for dtype in self.numeric_types: self._assertOpOutputMatchesExpected( @@ -72,7 +71,6 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): np.array([7, 11], dtype=dtype)), expected=np.array([[8, 13], [10, 15]], dtype=dtype)) - @test_util.disable_mlir_bridge('Not supported yet') def testBroadcast(self): for dtype in self.numeric_types: v = np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2]) @@ -110,7 +108,6 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): xla_data_pb2.PrecisionConfig.HIGHEST) @parameterized.parameters(*PRECISION_VALUES) - @test_util.disable_mlir_bridge('Not supported yet') def testConv(self, precision): for dtype in set(self.float_types).intersection( set([dtypes.bfloat16.as_numpy_dtype, np.float32])): @@ -195,8 +192,6 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): args=(np.array([1, 2, 3], dtype=dtype),), expected=np.array([-1, -2, -3], dtype=dtype)) - @test_util.disable_mlir_bridge( - 'Requires XlaPad op shape inference to have static result types') def testPad(self): for dtype in self.numeric_types: @@ -309,7 +304,6 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): self._assertOpOutputMatchesExpected( lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T) - @test_util.disable_mlir_bridge('Not supported yet') def testDynamicSlice(self): for dtype in self.numeric_types: self._assertOpOutputMatchesExpected( @@ -322,7 +316,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): [[673, 674], [683, 684], [693, 694]]]), dtype=dtype)) - @test_util.disable_mlir_bridge('Not supported yet') + 
@test_util.disable_mlir_bridge('Error handling') def testDynamicSliceWithIncorrectStartIndicesShape(self): with self.session() as session: with self.test_scope(): @@ -336,7 +330,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): (r'start_indices must be a vector with length equal to input rank, ' r'but input rank is 3 and start_indices has shape \[2\].*')) - @test_util.disable_mlir_bridge('Not supported yet') + @test_util.disable_mlir_bridge('Error handling') def testDynamicSliceWithIncorrectSizeIndicesShape(self): with self.session() as session: with self.test_scope(): diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index 356798c19bd..3d3eab51268 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -403,6 +403,7 @@ tf_cuda_library( "//tensorflow/core/grappler/costs:graph_properties", "//tensorflow/core/grappler/optimizers:meta_optimizer", "//tensorflow/stream_executor/lib", + "//tensorflow/tools/graph_transforms:transform_utils", ] + if_tensorrt([":tensorrt_lib"]) + tf_custom_op_library_additional_deps(), alwayslink = 1, ) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index 806d930b76f..414d27477bc 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -41,18 +41,16 @@ limitations under the License. #include "tensorflow/core/grappler/clusters/virtual_cluster.h" #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/devices.h" -#include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/optimizers/meta_optimizer.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/config.pb.h" // NOLINT #include "tensorflow/core/protobuf/device_properties.pb.h" // NOLINT #include "tensorflow/core/protobuf/rewriter_config.pb.h" // NOLINT #include "tensorflow/core/util/device_name_utils.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -90,8 +88,6 @@ bool AllowDynamicNonBatchDimension(const ConversionParams& params) { GetEngineType(params) == EngineInfo::EngineType::TRTDynamic; } -} // namespace - struct EdgePtrCompare { bool operator()(const Edge* lhs, const Edge* rhs) const { return lhs->id() < rhs->id(); @@ -555,6 +551,58 @@ Status CreateTRTNode(const ConversionParams& params, return Status::OK(); } +int64 GetNextGraphSequenceNumber() { + static std::atomic graph_sequence_num; + return graph_sequence_num++; +} + +constexpr char kCastInputTypeAttrName[] = "SrcT"; + +// Transforms node = cast(x, fp32) where datatype(x) != fp16 to: +// castToFp16 = cast(x, fp16) +// node = cast(castToFp16, fp32) +// +Status MaybeRewriteCastToFp32(GraphDef* graph_def, NodeDef* node_def) { + if (node_def->op() != "Cast") { + return Status::OK(); + } + + DataTypeVector input_types; + DataTypeVector output_types; + TF_RETURN_IF_ERROR( + graph_transforms::GetInOutTypes(*node_def, &input_types, &output_types)); + + if (input_types.size() != 1 || output_types.size() != 1) { + return errors::Internal("Bad cast operation"); + } + + if (input_types[0] == DT_HALF || output_types[0] != 
DT_FLOAT) { + return Status::OK(); + } + + VLOG(2) << "Rewriting cast to FP32 " << node_def->DebugString(); + + NodeDef* castToFp16 = graph_def->add_node(); + for (auto attr_value : node_def->attr()) { + (*castToFp16->mutable_attr())[attr_value.first] = attr_value.second; + } + castToFp16->set_name(node_def->name() + "_split"); + castToFp16->set_op("Cast"); + castToFp16->set_device(node_def->device()); + castToFp16->add_input(node_def->input(0)); + (*castToFp16->mutable_attr())[kCastOutputTypeAttrName].set_type(DT_HALF); + + node_def->set_input(0, castToFp16->name() + ":0"); + (*node_def->mutable_attr())[kCastInputTypeAttrName].set_type(DT_HALF); + + VLOG(2) << castToFp16->DebugString(); + VLOG(2) << node_def->DebugString(); + + return Status::OK(); +} + +} // namespace + Status RegisterGraphToFunctionLibrary(const GraphDef& segment_graph_def, Graph* graph, const string& engine_name) { Graph segment_graph(graph->flib_def()); @@ -629,11 +677,6 @@ std::pair GetDeviceAndAllocator(const ConversionParams& params, return std::make_pair(cuda_device_id, dev_allocator); } -int64 GetNextGraphSequenceNumber() { - static std::atomic graph_sequence_num; - return graph_sequence_num++; -} - // Entry function from optimization pass. Status ConvertAfterShapes(const ConversionParams& params) { // Sanity checks. @@ -643,12 +686,43 @@ Status ConvertAfterShapes(const ConversionParams& params) { "Calibration with FP32 or FP16 is not supported."); } + // Make a copy of the input_graph_def because grappler doesn't allow changes + // to the input_graph_def and GraphProperties only accepts GraphDef, but not + // Graph, as inputs. + // + // If the overhead of copying the input_graph_def becomes a concern, we can + // avoid the copy by (1) enhancing the GraphPropertiers representation to + // allow adding shape properties for newly created graph nodes and (2) rewrite + // the GraphDef transformation to Graph transformation. + GraphDef modified_graph_def = params.grappler_item->graph; + // When precision_mode is FP16, transform cast(x, fp32) to + // cast(cast(x, fp16), fp32). This creates cast(fp16, f32) that can be + // included in the TRTEngineOp as an TensorRT Identity layer for performance: + // . Avoid cast(fp32, fp16) in the TRT engine implementation for fp16 + // precision. + // . Changing the input to the TRTEngine from fp32 to fp16 may reduce data + // moving from the host to the GPU. + if (params.precision_mode == TrtPrecisionMode::FP16) { + for (int i = 0; i < modified_graph_def.node_size(); i++) { + NodeDef* node_def = modified_graph_def.mutable_node(i); + TF_RETURN_IF_ERROR(MaybeRewriteCastToFp32(&modified_graph_def, node_def)); + } + } + + // Construct a GrapplerItem using the modified graph_def and the input + // grappler_item. + grappler::GrapplerItem grappler_item = + params.grappler_item->WithGraph(std::move(modified_graph_def)); + const GraphDef& graph_def = grappler_item.graph; + + grappler::GraphProperties static_graph_properties(grappler_item); + TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true)); + // Convert graphdef to graph. 
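A rough Python sketch (illustration only) that mirrors the Cast-splitting rewrite MaybeRewriteCastToFp32 performs above: a node Cast(x, fp32) whose source type is not fp16 becomes Cast(Cast(x, fp16), fp32). The helper name maybe_rewrite_cast_to_fp32 and the toy GraphDef are hypothetical; only the attribute names SrcT/DstT and proto APIs are standard TensorFlow.

from tensorflow.core.framework import graph_pb2, types_pb2

def maybe_rewrite_cast_to_fp32(graph_def, node):
  """Splits node = Cast(x, fp32) with a non-fp16 source into Cast(Cast(x, fp16), fp32)."""
  if node.op != "Cast":
    return
  if (node.attr["SrcT"].type == types_pb2.DT_HALF or
      node.attr["DstT"].type != types_pb2.DT_FLOAT):
    return
  cast_to_fp16 = graph_def.node.add()
  cast_to_fp16.CopyFrom(node)                 # copy op, device and attrs
  cast_to_fp16.name = node.name + "_split"
  cast_to_fp16.ClearField("input")
  cast_to_fp16.input.append(node.input[0])
  cast_to_fp16.attr["DstT"].type = types_pb2.DT_HALF
  # Re-point the original Cast at the fp16 intermediate.
  node.input[0] = cast_to_fp16.name + ":0"
  node.attr["SrcT"].type = types_pb2.DT_HALF

g = graph_pb2.GraphDef()
cast = g.node.add(name="cast", op="Cast", input=["x"])
cast.attr["SrcT"].type = types_pb2.DT_INT32
cast.attr["DstT"].type = types_pb2.DT_FLOAT
maybe_rewrite_cast_to_fp32(g, cast)
assert cast.input[0] == "cast_split:0" and cast.attr["SrcT"].type == types_pb2.DT_HALF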
- FunctionLibraryDefinition flib(OpRegistry::Global(), - params.input_graph_def->library()); + FunctionLibraryDefinition flib(OpRegistry::Global(), graph_def.library()); Graph graph(flib); - TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(), - *params.input_graph_def, &graph)); + TF_RETURN_IF_ERROR( + ConvertGraphDefToGraph(GraphConstructorOptions(), graph_def, &graph)); // Segment the graph into subgraphs that can be converted to TensorRT segment::SegmentOptions segment_options; @@ -662,10 +736,10 @@ Status ConvertAfterShapes(const ConversionParams& params) { AllowDynamicNonBatchDimension(params); segment::SegmentNodesVector initial_segments; - TrtNodeValidator validator(*params.graph_properties, params.precision_mode, + TrtNodeValidator validator(static_graph_properties, params.precision_mode, params.use_calibration, params.use_implicit_batch); TF_RETURN_IF_ERROR(segment::SegmentGraph( - &graph, params.graph_properties, + &graph, &static_graph_properties, std::bind(&TrtNodeValidator::IsTensorRTCandidate, &validator, std::placeholders::_1), // Input validation is already done by TrtNodeValidator, so we don't @@ -693,9 +767,8 @@ Status ConvertAfterShapes(const ConversionParams& params) { auto& curr_segment = initial_segments.at(t); EngineInfo curr_engine; curr_engine.engine_name = StrCat(engine_name_prefix, t); - Status status = - GetEngineInfo(&graph, *params.graph_properties, curr_segment, node_map, - reverse_topo_order, &curr_engine); + Status status = GetEngineInfo(&graph, static_graph_properties, curr_segment, + node_map, reverse_topo_order, &curr_engine); if (!status.ok()) { LOG(WARNING) << "Failed to get engine info for segment " << t << ": " << status; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h index 2bfaa2a786c..53ab84a6fa9 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h @@ -18,10 +18,9 @@ limitations under the License. #include #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" -#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/grappler/clusters/cluster.h" -#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" @@ -33,7 +32,7 @@ namespace tensorrt { namespace convert { struct ConversionParams { - const GraphDef* input_graph_def = nullptr; + const grappler::GrapplerItem* grappler_item = nullptr; const std::vector* output_names = nullptr; string trt_logger_name; size_t max_batch_size = 1; @@ -41,7 +40,6 @@ struct ConversionParams { GraphDef* output_graph_def = nullptr; TrtPrecisionMode precision_mode = TrtPrecisionMode::FP32; int minimum_segment_size = 3; - const grappler::GraphProperties* graph_properties = nullptr; const grappler::Cluster* cluster = nullptr; // Whether to create engine on conversion or execution time bool is_dyn_op = false; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc index 2cfefd27a67..a1f523d6bfa 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc @@ -162,12 +162,11 @@ class ConvertAfterShapesTest : public ::testing::Test { // Construct ConversionParams. 
const std::vector output_names{"output"}; ConversionParams params; - params.input_graph_def = &item.graph; params.output_names = &output_names; params.max_workspace_size_bytes = 8 << 20; params.output_graph_def = output_graph_def; params.minimum_segment_size = 1; - params.graph_properties = &graph_properties; + params.grappler_item = &item; params.use_calibration = false; params.trt_logger_name = "DefaultLogger"; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 4aec6a2512c..8ca7c4cdf8f 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" @@ -403,6 +404,18 @@ Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l, // Compare broadcast feasibility if (check_feasibility) { for (int i = 0; i < broadcast_num_dims; ++i) { + if (!use_implicit_batch && (output_l[i] == -1 || output_r[i] == -1)) { + // If the condition is true then we are in explicit batch mode and (at + // least) one of the input dimensions are unknown. In other words we + // are in dynamic shape mode. During conversion time we only see -1 for + // the unknown shapes, therefore we cannot decide on the feasibility of + // broadcast over the unknown dimensions. Therefore we just continue for + // the next dimension. In dynamic shape mode TRT can only check the + // feasibility of the broadcast when the actual input dimensions are + // specified by SetTrtEngineInputs and the inference job is launched by + // TrtEnque. + continue; + } if ((output_l[i] != output_r[i]) && (output_l[i] != 1) && (output_r[i] != 1)) { return errors::InvalidArgument("Infeasible broadcast scheme (", @@ -795,6 +808,19 @@ nvinfer1::Dims TRT_TensorOrWeights::GetTrtDims() const { } } +Status TRT_TensorOrWeights::GetTfType(DataType* tf_type) const { + if (is_tensor()) { + nvinfer1::DataType trt_type = tensor()->getType(); + return TrtTypeToTfType(trt_type, tf_type); + } + + if (is_weights()) { + *tf_type = weights().GetTensor().dtype(); + return Status::OK(); + } + return errors::Internal("The object is probably not initialized"); +} + string TRT_TensorOrWeights::DebugString() const { string output = "TRT_TensorOrWeights(type="; if (is_tensor()) { @@ -1900,27 +1926,48 @@ Status CheckInputsWeights( return Status::OK(); } -Status AllowDataTypes(const OpConverterParams& params, - const std::set& allowed_dtypes, - const char* dtype_attr_name = "T") { - const auto& node_def = params.node_def; +Status GetNodeDefTfType(const NodeDef& node_def, DataType* tf_type, + const char* type_attr_name) { TFAttrs attrs(node_def); - if (!attrs.count(dtype_attr_name)) { - return errors::InvalidArgument("Attribute with name ", dtype_attr_name, + if (!attrs.count(type_attr_name)) { + return errors::InvalidArgument("Attribute with name ", type_attr_name, " not found."); } - const auto op_dtype = attrs.get(dtype_attr_name); - if (!allowed_dtypes.count(op_dtype)) { - // Build string list of allowed types. 
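A small sketch (illustration only, hypothetical helper name broadcast_feasible) of the feasibility rule extended in GetTrtBroadcastShape above: in explicit-batch mode a -1 (unknown) dimension cannot be validated at conversion time, so the check is skipped and deferred to TensorRT at runtime.

def broadcast_feasible(output_l, output_r, use_implicit_batch):
  for l, r in zip(output_l, output_r):
    if not use_implicit_batch and (l == -1 or r == -1):
      continue                      # unknown dim: defer the check to runtime
    if l != r and l != 1 and r != 1:
      return False
  return True

assert broadcast_feasible([2, -1, 3], [2, 5, 1], use_implicit_batch=False)
assert not broadcast_feasible([2, 4, 3], [2, 5, 1], use_implicit_batch=True)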
- std::ostringstream ss; - for (auto it = allowed_dtypes.begin(); it != allowed_dtypes.end(); ++it) { - if (it != allowed_dtypes.begin()) ss << ", "; - ss << DataTypeString(*it); - } - return errors::Unimplemented("Data type ", DataTypeString(op_dtype), + *tf_type = attrs.get(type_attr_name); + return Status::OK(); +} + +Status GetInputTfType(const OpConverterParams& params, DataType* tf_type, + int pos) { + const std::vector& inputs = params.inputs; + if (inputs.size() <= pos) { + return errors::Internal("Invalid input position"); + } + + return inputs[pos].GetTfType(tf_type); +} + +constexpr const char kOutputTypeAttrName[] = "T"; + +Status GetOutputTfType(const OpConverterParams& params, DataType* tf_type) { + return GetNodeDefTfType(params.node_def, tf_type, kOutputTypeAttrName); +} + +Status AllowDataTypes(const OpConverterParams& params, + const std::set& allowed_types, + const char* type_attr_name = kOutputTypeAttrName) { + const auto& node_def = params.node_def; + DataType tf_type; + TF_RETURN_IF_ERROR(GetNodeDefTfType(node_def, &tf_type, type_attr_name)); + if (!allowed_types.count(tf_type)) { + string allowed_types_string = absl::StrJoin( + allowed_types, ", ", [](string* out, const DataType& type) { + absl::StrAppendFormat(out, "%s", DataTypeString(type)); + }); + return errors::Unimplemented("Data type ", DataTypeString(tf_type), " is not supported for ", node_def.op(), - ", must be one of [", ss.str(), "], at ", - node_def.name()); + ", must be one of [", allowed_types_string, + "], at ", node_def.name()); } return Status::OK(); } @@ -2111,6 +2158,12 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group, "Stride must be 1 for batch and channel dimensions, at ", node_def.name()); } + // Channel dim must be static for DepthwiseConv2dNative since we use that + // value for num_groups at build time. + if (!params->use_implicit_batch && tensor->getDimensions().d[c_index] == -1) { + return errors::InvalidArgument("Channel dimension must be static, at ", + node_def.name()); + } const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); if (params->validation_only) return Status::OK(); @@ -2122,11 +2175,12 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group, } // Dimensions of transposed tensor. const auto tensor_dim = tensor->getDimensions(); + const int c_dim_size = tensor_dim.d[params->use_implicit_batch ? 0 : 1]; // group == 0 signifies that this is a depthwise convolution, so set // num_groups to size of input's channel dim. For a non-depthwise conv, // num_groups will be 1. - const int num_groups = (group == 0) ? tensor_dim.d[0] : group; + const int num_groups = (group == 0) ? c_dim_size : group; // For conv, TF weights are RSCK, and TRT expects KCRS. // For backprop, TF weights are RSKC, and TRT expects CKRS. @@ -2413,26 +2467,19 @@ Status ConvertExpandDims(OpConverterParams* params) { } Status Converter::SqueezeTensor(nvinfer1::ITensor* input, - const std::vector& trt_axes, + std::vector* input_dims, nvinfer1::ITensor** output) { - const nvinfer1::Dims dims = input->getDimensions(); - std::vector input_dims(dims.d, dims.d + dims.nbDims); - // Mark axes to remove by setting them to 0. - for (int axis : trt_axes) { - input_dims[axis] = 0; - } - #if IS_TRT_VERSION_GE(6, 0, 0, 0) // If the remaining dimensions of a squeeze operation have dynamic sizes, we // need to use TRT ops to build the result shape for the squeeze operation. // This is because IShuffleLayer::setReshapeDimensions treats -1 as a special // value. 
- if (absl::c_any_of(input_dims, [](int i) { return i == -1; })) { + if (absl::c_any_of(*input_dims, [](int i) { return i == -1; })) { nvinfer1::ITensor* shape = network()->addShape(*input)->getOutput(0); std::vector concat_inputs; - for (int i = 0; i < input_dims.size(); i++) { + for (int i = 0; i < input_dims->size(); i++) { // If input dim wasn't set to 0 earlier, we include it in new shape. - if (input_dims[i] != 0) { + if (input_dims->at(i) != 0) { concat_inputs.push_back( network() ->addSlice(*shape, {1, {i}}, {1, {1}}, {1, {1}}) @@ -2452,11 +2499,12 @@ Status Converter::SqueezeTensor(nvinfer1::ITensor* input, } #endif // Remove all dims which are equal to 0. - input_dims.erase(std::remove(input_dims.begin(), input_dims.end(), 0), - input_dims.end()); + input_dims->erase(std::remove(input_dims->begin(), input_dims->end(), 0), + input_dims->end()); // Reshape tensor. nvinfer1::Dims new_dims; - TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims)); + VLOG(2) << "input_dims" << input_dims; + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(*input_dims, &new_dims)); TF_RETURN_IF_ERROR(PrepareTensorForShape(TRT_TensorOrWeights(input), new_dims, /*validation_only=*/false, output)); return Status::OK(); @@ -2475,31 +2523,48 @@ Status ConvertSqueeze(OpConverterParams* params) { TFAttrs attrs(node_def); auto squeeze_dims = attrs.get>("squeeze_dims"); if (squeeze_dims.empty()) { - return errors::Unimplemented( - "Squeeze is only implemented for explicit dims, at ", node_def.name()); - } - std::vector trt_axes; - trt_axes.reserve(squeeze_dims.size()); - for (int tf_axis : squeeze_dims) { - // If the axis is valid, then convert it to TRT axis, otherwise abort - // conversion. - int trt_axis; - TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims, node_def.name(), - params->use_implicit_batch, &trt_axis)); - // Make sure target dimension is size 1 or unknown size (-1) - if (input_dims[trt_axis] != -1 && input_dims[trt_axis] != 1) { - return errors::InvalidArgument( - "Dimension ", tf_axis, " with size ", input_dims[trt_axis], - " cannot be squeezed because it must be size 1, at ", + if (params->use_implicit_batch || !HasStaticShape(dims)) { + return errors::Unimplemented( + "Squeeze is not implemented for empty squeeze_dims, at ", node_def.name()); + } else { + // explicit batch mode with static input shape we squeeze all singleton + // dimensions + for (int& dim : input_dims) { + if (dim == 1) { + // Mark it for removal by setting it to 0 + dim = 0; + } + } + } + } else { + std::vector trt_axes; + trt_axes.reserve(squeeze_dims.size()); + for (int tf_axis : squeeze_dims) { + // If the axis is valid, then convert it to TRT axis, otherwise abort + // conversion. + int trt_axis; + TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims, node_def.name(), + params->use_implicit_batch, &trt_axis)); + // Make sure target dimension is size 1 or unknown size (-1) + if (input_dims[trt_axis] != -1 && input_dims[trt_axis] != 1) { + return errors::InvalidArgument( + "Dimension ", tf_axis, " with size ", input_dims[trt_axis], + " cannot be squeezed because it must be size 1, at ", + node_def.name()); + } + trt_axes.push_back(trt_axis); + } + // Mark axes to remove by setting them to 0. 
+ for (int axis : trt_axes) { + input_dims[axis] = 0; } - trt_axes.push_back(trt_axis); } if (params->validation_only) return Status::OK(); nvinfer1::ITensor* output_tensor = nullptr; TF_RETURN_IF_ERROR(params->converter->SqueezeTensor( - input_tensor.tensor(), trt_axes, &output_tensor)); + input_tensor.tensor(), &input_dims, &output_tensor)); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); } @@ -3842,11 +3907,26 @@ Status ConvertBiasAdd(OpConverterParams* params) { nvinfer1::Dims input_shape = inputs.at(0).GetTrtDims(); nvinfer1::Dims bias_shape = inputs.at(1).GetTrtDims(); - // If the input is NCHW, then we need to unsqueeze the bias such that its last - // dimensions are 1s (and the first dimension is C). + // The bias input arg is a 1-D tensor with length C. If the input is NCHW, + // then we need to unsqueeze the bias such that its shape is [1, C, 1, 1]. if (data_format == "NCHW") { - bias_shape.nbDims = input_shape.nbDims; - std::fill(bias_shape.d + 1, bias_shape.d + bias_shape.nbDims, 1); + if (params->use_implicit_batch) { + // The batch dim is not included in implicit batch mode, so the shape of + // the bias tensor is [C, 1, 1]. + bias_shape.nbDims = input_shape.nbDims; + std::fill(bias_shape.d + 1, bias_shape.d + bias_shape.nbDims, 1); + } else { + // In explicit batch mode we create a tensor with shape [1, C, 1, 1]. + std::vector bias_shape_vec(bias_shape.d, + bias_shape.d + bias_shape.nbDims); + // Insert 1 before for batch dim + bias_shape_vec.insert(bias_shape_vec.begin(), 1); + // Trail with 1s to match input_shape size + bias_shape_vec.insert(bias_shape_vec.end(), + input_shape.nbDims - bias_shape_vec.size(), 1); + TF_RETURN_IF_ERROR( + TensorShapeArrayToTrtDims(bias_shape_vec, &bias_shape)); + } } else { // Next, broadcast the bias across the input. TF_RETURN_IF_ERROR(GetTrtBroadcastShape(inputs.at(0), inputs.at(1), @@ -4587,6 +4667,44 @@ Status ConvertUnpack(OpConverterParams* params) { return ConvertSplitHelper(params, inputs.at(0), tf_axis, num, true); } +// Supports cast fp16=>fp32 through IIdentityLayer. 
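A minimal Python sketch (illustration only, hypothetical helper name squeeze_dims) of the dim-marking convention used by the updated SqueezeTensor/ConvertSqueeze above: axes to drop are first set to 0 in the dimension vector, and with empty squeeze_dims in explicit batch mode every static size-1 dimension is marked; the 0 entries are then removed to form the reshape target.

def squeeze_dims(input_dims, axes=None):
  dims = list(input_dims)
  if axes is None:                      # empty squeeze_dims: drop all static 1s
    dims = [0 if d == 1 else d for d in dims]
  else:
    for a in axes:
      assert dims[a] in (1, -1), "only size-1 or unknown dims can be squeezed"
      dims[a] = 0                       # mark for removal
  return [d for d in dims if d != 0]

assert squeeze_dims([2, 1, 3, 1]) == [2, 3]
assert squeeze_dims([2, 1, 3], axes=[1]) == [2, 3]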
+Status ConvertCast(OpConverterParams* params) { + const NodeDef& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}})); + auto unsupport_cast_error = [&]() { + return errors::Unimplemented("Cast op: ", node_def.op(), + " not supported at: ", node_def.name()); + }; + + DataType input_type; + TF_RETURN_IF_ERROR(GetInputTfType(*params, &input_type, 0)); + if (input_type != DataType::DT_HALF) { + return unsupport_cast_error(); + } + + DataType output_type; + TF_RETURN_IF_ERROR(GetNodeDefTfType(params->node_def, &output_type, + kCastOutputTypeAttrName)); + + if (output_type != DataType::DT_FLOAT) { + return unsupport_cast_error(); + } + + if (params->validation_only) return Status::OK(); + + nvinfer1::ITensor* input = params->inputs.at(0).tensor(); + nvinfer1::IIdentityLayer* layer = + params->converter->network()->addIdentity(*input); + layer->setPrecision(nvinfer1::DataType::kFLOAT); + + if (layer->getOutput(0)->getType() != nvinfer1::DataType::kFLOAT) { + return errors::Internal("IIdentityLayer doesn't work as expected"); + } + + params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); + return Status::OK(); +} + Status ConvertConcat(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; @@ -5664,6 +5782,7 @@ static void RegisterValidatableOpConverters( (*registration)["CombinedNonMaxSuppression"] = ConvertCombinedNMS; #endif (*registration)["AddN"] = ConvertAddN; + (*registration)["Cast"] = ConvertCast; (*registration)["ConcatV2"] = ConvertConcat; (*registration)["Const"] = ConvertConst; (*registration)["Conv2D"] = ConvertConv2D; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index 8608c8226ee..2fe8eec9675 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -294,6 +294,8 @@ class TRT_TensorOrWeights { nvinfer1::Dims GetTrtDims() const; + Status GetTfType(DataType* tf_type) const; + int batch_size() const { return batch_size_; } string DebugString() const; @@ -529,11 +531,9 @@ class Converter { // Helper function to add a squeeze op to the network. // - // The trt_axes argument lists those axes that need to be squeezed. Each axis - // in the list is numbered according to TRT convention (see ConvertAxis for - // details). - Status SqueezeTensor(nvinfer1::ITensor* input, - const std::vector& trt_axes, + // The input_dims argument stores the TRT dimensions of the input tensor, + // where the dimensions to be squeezed are replaced by 0. + Status SqueezeTensor(nvinfer1::ITensor* input, std::vector* input_dims, nvinfer1::ITensor** output); // Creates an IConstantLayer using 'weights' whose dimensions are specified by diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index 1f30b837450..f9b0cafe253 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" #include +#include #include #include #include @@ -67,7 +68,6 @@ using absl::StrCat; using ::testing::ElementsAre; using ::testing::ElementsAreArray; using ::testing::Matcher; -using ::testing::NanSensitiveFloatNear; // TensorRT modes for testing. We define the following three modes: // 1. 
Implicit batch mode: The tensors have static (known) input shape and the @@ -135,30 +135,18 @@ std::ostream& operator<<(std::ostream& os, const std::vector& v) { return os; } -nvinfer1::DataType TfDataTypeToTrt(DataType tf_dtype) { - switch (tf_dtype) { - case DT_FLOAT: - return nvinfer1::DataType::kFLOAT; - case DT_HALF: - return nvinfer1::DataType::kHALF; - case DT_INT32: - return nvinfer1::DataType::kINT32; - default: - QCHECK(false) << "Unexpected data type " << DataTypeString(tf_dtype); - } +nvinfer1::DataType TfDataTypeToTrt(DataType tf_type) { + nvinfer1::DataType trt_type; + Status status = TfTypeToTrtType(tf_type, &trt_type); + EXPECT_EQ(status, Status::OK()); + return trt_type; } -DataType TrtDataTypeToTf(nvinfer1::DataType trt_dtype) { - switch (trt_dtype) { - case nvinfer1::DataType::kFLOAT: - return DT_FLOAT; - case nvinfer1::DataType::kHALF: - return DT_HALF; - case nvinfer1::DataType::kINT32: - return DT_INT32; - default: - QCHECK(false) << "Unexpected data type " << static_cast(trt_dtype); - } +DataType TrtDataTypeToTf(nvinfer1::DataType trt_type) { + DataType tf_type; + Status status = TrtTypeToTfType(trt_type, &tf_type); + EXPECT_EQ(status, Status::OK()); + return tf_type; } NodeDef MakeNodeDef(const string& name, const string& op, @@ -216,6 +204,24 @@ void ExpectTrtDimsEqualsArray(const std::vector& lhs, << " actual: " << DebugString(rhs); } +Matcher> ArrayFloatNear(const std::vector& values, + float max_abs_error = 1e-5, + bool nan_sensitive = false) { + std::vector> matchers; + matchers.reserve(values.size()); + for (const float& v : values) { + if (nan_sensitive) { + matchers.emplace_back(::testing::NanSensitiveFloatNear(v, max_abs_error)); + } else if (max_abs_error == 0) { + matchers.emplace_back(::testing::FloatEq(v)); + } else { + EXPECT_GE(max_abs_error, 0); + matchers.emplace_back(::testing::FloatNear(v, max_abs_error)); + } + } + return ElementsAreArray(matchers); +} + template void ExpectArrayNear(const std::vector& lhs, absl::Span rhs) { ASSERT_EQ(lhs.size(), rhs.size()); @@ -293,7 +299,8 @@ struct StaticCaster { }; template -std::vector CastTestVector(const std::vector& vals) { +std::vector CastTestVector( + const gtl::ArraySlice& vals) { // non-absl ok std::vector res(vals.size()); std::transform(vals.begin(), vals.end(), res.begin(), StaticCaster()); @@ -1283,6 +1290,21 @@ inline absl::Span GetSpanForData(const InputOutputData& data) { return absl::Span(tensor_map.data(), tensor_map.size()); } +std::vector GetDataAsFloat(InputOutputData& data) { + if (data.tensor.dtype() == DT_FLOAT) { + auto span = GetSpanForData(data); + return std::vector(span.begin(), span.end()); + } + if (data.tensor.dtype() == DT_HALF) { + return CastTestVector( + GetSpanForData(data)); + } + if (data.tensor.dtype() == DT_INT32) { + return CastTestVector(GetSpanForData(data)); + } + LOG(FATAL) << "DataType not supported for testing " + << DataTypeString(data.tensor.dtype()); +} // Class to test various op converters, using both a TrtNodeValidator and // Converter. class OpConverterTest : public ::testing::Test { @@ -1336,6 +1358,33 @@ class OpConverterTest : public ::testing::Test { return ret; } + // Constructs a tensor with given values (vals). The tensor type is defined by + // the tf_dtype argument, its shape is given by input_dims. The tensor is + // constructed using the allocator of OpConverterTest in Unified Memory. 
+ template + Tensor AsTensor(std::vector vals, const std::vector input_dims, + DataType tf_dtype) { + Tensor ret(allocator_.get(), tf_dtype, {static_cast(vals.size())}); + if (tf_dtype == DT_FLOAT) { + auto conv_vals = CastTestVector(vals); + std::copy_n(conv_vals.data(), conv_vals.size(), ret.flat().data()); + } else if (tf_dtype == DT_HALF) { + auto conv_vals = CastTestVector(vals); + std::copy_n(conv_vals.data(), conv_vals.size(), + ret.flat().data()); + } else if (tf_dtype == DT_INT32) { + auto conv_vals = CastTestVector(vals); + std::copy_n(conv_vals.data(), conv_vals.size(), ret.flat().data()); + } else { + LOG(FATAL) << "Cannot create tensor with type " + << DataTypeString(tf_dtype); + } + TensorShape shape; + TF_EXPECT_OK(TensorShapeUtils::MakeShape(input_dims, &shape)); + CHECK(ret.CopyFrom(ret, shape)); + return ret; + } + // Constructs a flat tensor in Unified Memory. template Tensor ConstructTensor(int data_size, const T& value = T()) { @@ -1343,6 +1392,13 @@ class OpConverterTest : public ::testing::Test { return AsTensor(values); } + // Constructs a flat tensor in Unified Memory. + template + Tensor ConstructTensor(int data_size, const T& value, DataType tf_dtype) { + std::vector values(data_size, value); + return AsTensor(values, {data_size}, tf_dtype); + } + void CheckDataTypeMatches(const DataVec& datas) { for (const auto& data : datas) { const int input_index = engine_->getBindingIndex(data.name.c_str()); @@ -1356,27 +1412,29 @@ class OpConverterTest : public ::testing::Test { } } - void BuildAndRun(const DataVec& input_data, DataVec* output_data, - const int batch_size = 1) { + Status BuildAndRun(const DataVec& input_data, DataVec* output_data, + const int batch_size = 1) { // Mark the output tensor as TRT engine output. std::vector output_info; for (const auto& data : *output_data) { output_info.push_back( {data.name, data.name, TfDataTypeToTrt(data.tensor.dtype())}); } - TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors(output_info)); + TF_RETURN_IF_ERROR(converter_->RenameAndMarkOutputTensors(output_info)); // Build the TRT engine. - ASSERT_EQ(nullptr, engine_.get()); + if (engine_.get() != nullptr) { + return errors::Internal("Engine already exists"); + } TrtShapeOptimizationProfile profiles; if (!converter_->use_implicit_batch()) { // Create a single optimization profile for explicit batch mode std::vector input_shapes; - TF_ASSERT_OK(GetShapeFromDataVec(input_data, &input_shapes)); + TF_RETURN_IF_ERROR(GetShapeFromDataVec(input_data, &input_shapes)); profiles.AddShape(input_shapes); profiles.InitProfiles(); } - TF_ASSERT_OK( + TF_RETURN_IF_ERROR( converter_->BuildCudaEngine(&engine_, /*max_batch_size=*/batch_size, /*max_workspace_size_bytes=*/1 << 26, @@ -1390,7 +1448,9 @@ class OpConverterTest : public ::testing::Test { const int num_bindings = input_data.size() + output_data->size(); std::vector buffers(num_bindings); - ASSERT_EQ(engine_->getNbBindings(), num_bindings); + if (engine_->getNbBindings() != num_bindings) { + return errors::Internal("Number of bindings do not match"); + } // Since we have only 1 optimization profile (which is enabled by default) // it is fine to create execution context directly, instead of calling // profiles.CreateExecutionContexts() @@ -1398,19 +1458,19 @@ class OpConverterTest : public ::testing::Test { engine_->createExecutionContext()); // Prepare input bindings. 
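A rough NumPy sketch (illustration only, hypothetical helper names as_test_tensor and data_as_float) of the dtype plumbing added above: test inputs are materialized in the parameterized tf_dtype (float32, float16 or int32), and outputs are converted back to float32 so a single matcher can be reused across type parameterizations.

import numpy as np

_SUPPORTED = {"float32": np.float32, "float16": np.float16, "int32": np.int32}

def as_test_tensor(values, shape, dtype_name):
  dtype = _SUPPORTED[dtype_name]          # unsupported dtypes would be an error
  return np.asarray(values, dtype=dtype).reshape(shape)

def data_as_float(tensor):
  return tensor.astype(np.float32).ravel().tolist()

out = as_test_tensor([1, 2, 3, 4, 5, 6], (1, 2, 3), "float16")
assert data_as_float(out) == [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]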
- TF_ASSERT_OK(SetTrtEngineInputs(engine_.get(), execution_context.get(), 0, - buffers, converter_->use_implicit_batch(), - batch_size, nullptr, &input_data)); - + TF_RETURN_IF_ERROR(SetTrtEngineInputs( + engine_.get(), execution_context.get(), 0, buffers, + converter_->use_implicit_batch(), batch_size, nullptr, &input_data)); // Prepare output bindings. - TF_ASSERT_OK(SetTrtEngineOutputs(engine_.get(), execution_context.get(), 0, - buffers, converter_->use_implicit_batch(), - batch_size, nullptr, output_data)); - + TF_RETURN_IF_ERROR(SetTrtEngineOutputs( + engine_.get(), execution_context.get(), 0, buffers, + converter_->use_implicit_batch(), batch_size, nullptr, output_data)); // Execute the TRT engine. - TF_ASSERT_OK(TrtEnqueue(execution_context.get(), buffers, stream_, - converter_->use_implicit_batch(), batch_size)); + TF_RETURN_IF_ERROR(TrtEnqueue(execution_context.get(), buffers, stream_, + converter_->use_implicit_batch(), + batch_size)); cudaStreamSynchronize(stream_); + return Status::OK(); } bool HasStaticShape(const nvinfer1::Dims& dims) const { @@ -1427,7 +1487,7 @@ class OpConverterTest : public ::testing::Test { // Adds ITensor for both validation and conversion, assuming explicit batch // dimension is included in dims (ie for an NCHW tensor dims = {N, C, H, W}). - void AddTestTensorWithExplicitBatchDim( + void AddTestTensorWithTFDims( const string& name, const std::vector& dims, nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT) { DataType tf_dtype = TrtDataTypeToTf(trt_dtype); @@ -1447,54 +1507,19 @@ class OpConverterTest : public ::testing::Test { } } - // Adds ITensor for both validation and conversion. The tensor can have - // partial input shape. This function defines static or dynamic shape input - // tensor for the network based on the trt_mode attribute. This is done - // automatically, unless the user overrides it with an explicit - // partial_input_shape_dims argument. - // - // Parameters: - // - dims actual dimensions of the tensor that we will use during the test - // (including explicit batch dim). This is not used if partial_input_shape - // is defined. - // - partial_input_shape dimensions which can incude unknown shapes. This can - // be empty, in that case the partial_input_shape will be set automatically - // depending on the trt_mode argument. (This also includse explicit batch - // dim). - // - // On return skip_test is false if trt_mode is not compatible with the - // partial input shape. - void AddTestTensor( - const string& name, const std::vector& dims, - nvinfer1::DataType trt_dtype, TrtTestMode trt_mode, - const std::vector* partial_input_shape_dims = nullptr) { - std::vector partial_shape; - if (partial_input_shape_dims && !partial_input_shape_dims->empty()) { - partial_shape = *partial_input_shape_dims; - } else { - if (trt_mode == TrtTestMode::kDynamicShape) { - // In dynamic shape mode we set the all dims unknown. - partial_shape = std::vector(dims.size(), -1); - } else { - // Use static (known) input shapes. - partial_shape = dims; - } - } - AddTestTensorWithExplicitBatchDim(name, partial_shape, trt_dtype); - } - // Adds ITensor for both validation and conversion. The difference compared to - // AddTestTensorWithExplicitBatchDim is in the meaning of the dims parameter. - // To define a tensor with NCHW shape, here we set dims = {C,H,W} and - // batch_size = N. TODO(tfeher) remove this function once all test are updated - // to use the other version of AddTestTensor which has the trt_mode arg. 
+ // AddTestTensorWithTFDims is in the meaning of the dims parameter. To define + // a tensor with NCHW shape, here we set dims = {C,H,W} and batch_size = N. + // TODO(tfeher) remove this function once all test are updated to use the + // other version of AddTestTensor (defined by + // ParameterizedOpConverterTestBase). void AddTestTensor( const string& name, const std::vector& dims, int batch_size = 1, nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT) { std::vector dims_with_batch(dims.size() + 1); dims_with_batch[0] = batch_size; std::copy(dims.begin(), dims.end(), dims_with_batch.begin() + 1); - AddTestTensorWithExplicitBatchDim(name, dims_with_batch, trt_dtype); + AddTestTensorWithTFDims(name, dims_with_batch, trt_dtype); if (HasStaticShape(dims)) { ASSERT_EQ(batch_size, converter_->batch_size_); } @@ -1527,6 +1552,21 @@ class OpConverterTest : public ::testing::Test { converter_->AddTensorOrWeights(name, TRT_TensorOrWeights{weights})); } + template + void AddTestWeights(const string& name, const std::vector& dims, + const std::vector& values, DataType tf_dtype) { + if (tf_dtype == DT_FLOAT) { + AddTestWeights(name, dims, CastTestVector(values)); + } else if (tf_dtype == DT_HALF) { + AddTestWeights(name, dims, CastTestVector(values)); + } else if (tf_dtype == DT_INT32) { + AddTestWeights(name, dims, CastTestVector(values)); + } else { + FAIL() << "Cannot create test weights with type " + << DataTypeString(tf_dtype); + } + } + // Test validation in validation-only mode. void RunValidation(const Node* node, error::Code expected_code = error::OK, const char* expected_msg_substr = nullptr) { @@ -1664,20 +1704,146 @@ std::ostream& operator<<(std::ostream& os, const TestParamBase& p) { return os; } -// Parameterized version of OpConverterTest. This class will be instantiated -// to test all the TrtTestModes but only in FP32 precision. This means that we -// will use the following combinations of test parameters: +// Parameterized version of OpConverterTest. We have the following parameters: // 1. TrtTestMode: implicit batch, explicit batch, dynamic shape modes -// 2. DataType of the input TF tensors: DT_FLOAT -// 3. TrtPrecisionMode argument for the Converter: FP32 -class ParameterizedOpConverterTest +// 2. DataType of the input TF tensors: DT_FLOAT, DT_HALF, DT_INT32 +// 3. TrtPrecisionMode argument for the Converter: FP32, FP16, INT8 +// We will introduce subclasses that will be instantiated using different +// combinations of the DataType and TrtPrecisionMode parameters. +class ParameterizedOpConverterTestBase : public OpConverterTest, public ::testing::WithParamInterface< - std::tuple> {}; + std::tuple> { + public: + ParameterizedOpConverterTestBase() + : trt_mode(std::get<0>(GetParam())), + tf_dtype(std::get<1>(GetParam())), + converter_precision(std::get<2>(GetParam())) {} -// Instantiate parameter combinations to test. For debugging purposes it might -// make sense to run over all possible combinations, but normally a subset of -// them would be sufficient: + void Reset() { + OpConverterTest::Reset(converter_precision, trt_mode); + input_data_.clear(); + } + + // Adds an input ITensor for TRT network. Also creates the corresponding TF + // tensor, and stores it in the list of inputs (input_data_). + // + // The TF tensor is always created with concrete static input shape given by + // dims. The ITensor can have static or dynamic shape based on the trt_mode + // attribute. 
The ITensor shape is set automatically according to the trt_mode + // parameter, unless the user overrides it with an explicit + // partial_input_shape_dims argument. + // + // Parameters: + // - name of the input node + // - dims actual dimensions of the tensor that we will use during the test + // (including explicit batch dim) + // - values initial values for the TF tensor + // - dtype data type of the tensor + // - partial_input_shape dimensions which can incude unknown shapes. This can + // be empty, in that case the partial_input_shape will be set automatically + // depending on the trt_mode argument. (This argument also includes explicit + // batch dim). + // + template + void AddTestTensor(const string& name, const std::vector& dims, + DataType tf_dtype, const std::vector& values, + const std::vector& partial_input_shape_dims = {}) { + std::vector partial_shape; + if (!partial_input_shape_dims.empty()) { + partial_shape = partial_input_shape_dims; + } else { + if (trt_mode == TrtTestMode::kDynamicShape) { + // In dynamic shape mode we make all dims unknown. + partial_shape = std::vector(dims.size(), -1); + } else { + // Use static (known) input shapes. + partial_shape = dims; + } + } + AddTestTensorWithTFDims(name, partial_shape, TfDataTypeToTrt(tf_dtype)); + if (!values.empty()) { + VLOG(2) << "Adding test tensor: " << name << " " + << DataTypeString(tf_dtype); + InputOutputData data{name, AsTensor(values, dims, tf_dtype)}; + VLOG(2) << "Added tensor: " << data.name + << DataTypeString(data.tensor.dtype()); + input_data_.push_back(data); + } + } + + // Adds test tensor (same as above) but with the default tf_dtype defined by + // the test params. + void AddTestTensor(const string& name, const std::vector& dims, + const std::vector& values = {}, + const std::vector& partial_input_shape_dims = {}) { + AddTestTensor(name, dims, tf_dtype, values, + partial_input_shape_dims); + } + + // Builds and runs the converted network. Checks output tensor shape. Tests + // output values using a matcher. The network can have multiple input and + // output tensors. The inputs are defined by the input_data_ member variable. + void BuildAndRun(const string& name, + const std::vector>& expected_output_dims, + const Status& expected_runtime_status, + const std::vector>>& matcher) { + TensorShape shape; + const int n_output = expected_output_dims.size(); + ASSERT_EQ(n_output, matcher.size()); + DataVec output_data; + for (int i = 0; i < n_output; i++) { + TF_EXPECT_OK( + TensorShapeUtils::MakeShape(expected_output_dims[i], &shape)); + string out_name = (n_output == 1) ? name : StrCat(name, ":", i); + InputOutputData data{out_name, + ConstructTensor(shape.num_elements(), 0, tf_dtype)}; + output_data.push_back(data); + } + ASSERT_FALSE(input_data_.empty()); + const int batch_size = input_data_[0].tensor.shape().dim_size(0); + Status stat = + OpConverterTest::BuildAndRun(input_data_, &output_data, batch_size); + ASSERT_EQ(expected_runtime_status, stat); + if (expected_runtime_status.ok() && stat.ok()) { + for (int i = 0; i < n_output; i++) { + // Check the shape of the actual output tensors + TF_EXPECT_OK( + TensorShapeUtils::MakeShape(expected_output_dims[i], &shape)); + EXPECT_TRUE(output_data[i].tensor.shape() == shape) + << "Expected shape: " << shape.DebugString() << ", actual shape" + << output_data[i].tensor.shape().DebugString(); + EXPECT_THAT(GetDataAsFloat(output_data[i]), matcher[i]); + } + } + } + + // Runs validation and conversion. 
If conversion is successfull then builds + // the TRT network, executes it and checks the output. + void TestOpConverter(const string& name, const NodeDef node_def, + const std::vector& expected_output_dims, + const Status& expected_conversion_status, + const Status& expected_runtime_status, + const Matcher>& matcher) { + RunValidationAndConversion(node_def, expected_conversion_status, + name.c_str(), expected_output_dims); + if (expected_conversion_status.ok()) { + BuildAndRun(name, std::vector>({expected_output_dims}), + expected_runtime_status, + std::vector>>({matcher})); + } + } + + protected: + const TrtTestMode trt_mode; + const DataType tf_dtype; + const TrtPrecisionMode converter_precision; + DataVec input_data_; +}; + +// Op converter test in FP32 mode. While for debugging purposes it might make +// sense to run over all possible combinations, normally a subset of them +// would be sufficient: // - All valid options to TrtTestMode (implicit, explicit, dynamic shape) // - DataType: is the TF data type of the input tensors. This usually only // influences the data type added by Converter::AddInputTensor. We test the @@ -1687,66 +1853,22 @@ class ParameterizedOpConverterTest // how TRT handles the precision inside the TRT network, but should not matter // for the TF -> TRT conversion. Therefore it should be sufficient to test // for FP32. +class OpConverterTest1 : public ParameterizedOpConverterTestBase {}; + +// Instantiate parameter combinations to OpConverterTest1 INSTANTIATE_TEST_CASE_P( - OpConvTestInstantiation, ParameterizedOpConverterTest, + OpConvTestInstantiation, OpConverterTest1, ::testing::Combine(::testing::ValuesIn(ValidTrtModes), ::testing::Values(DT_FLOAT), ::testing::Values(TrtPrecisionMode::FP32))); -// Builds and runs the converted network. Checks output tensor shape. Tests -// output values using a matcher. -template -void BuildAndRunConvertedNetwork(const string& name, OpConverterTest* test, - const TestParamBase& p, - const std::vector& input_vec, - const Matcher>& matcher) { - if (!p.status.ok()) { - // conversion was not successful, we cannot run the network - return; - } - if (!p.runtime_status.ok()) { - // Runtime error is expected. This can happen if the operation is invalid - // for the actual input shape. Usually we catch these errors during - // conversion. If the network was defined with dynamic input shape than we - // have to postpone these steps until runtime. - // - // TODO(tfeher) Instead of early return, modify BuildAndRun to handle - // runtime errors. 
- return; - } - typedef typename EnumToDataType::Type T; - TensorShape shape; - TF_EXPECT_OK(TensorShapeUtils::MakeShape(p.input_dims, &shape)); - const DataVec input_data{ - {"input", test->AsTensor(CastTestVector(input_vec), shape)}}; - DataVec output_data{{name, test->ConstructTensor(6)}}; - test->BuildAndRun(input_data, &output_data); - // Check the shape of the actual output tensor - TF_EXPECT_OK(TensorShapeUtils::MakeShape(p.expected_output_dims, &shape)); - EXPECT_TRUE(output_data[0].tensor.shape() == shape) - << "Expected shape: " << shape.DebugString() << ", actual shape" - << output_data[0].tensor.shape().DebugString(); - // Cast the output to float and compare to expected output - auto out_span = GetSpanForData(output_data[0]); - std::vector casted_output(out_span.begin(), out_span.end()); - EXPECT_THAT(casted_output, matcher); -} - -void InstantiateBuildAndRun(DataType tf_dtype, const string& name, - OpConverterTest* test, const TestParamBase& p, - const std::vector& input_vec, - const Matcher>& matcher) { - if (tf_dtype == DT_FLOAT) { - BuildAndRunConvertedNetwork(name, test, p, input_vec, matcher); - } else if (tf_dtype == DT_HALF) { - BuildAndRunConvertedNetwork(name, test, p, input_vec, matcher); - } else if (tf_dtype == DT_INT32) { - BuildAndRunConvertedNetwork(name, test, p, input_vec, matcher); - } else { - FAIL() << "Test not supported for " << tf_dtype; - } -} - +// Base class for tests that need to be tested for both FP32 and FP16. +class OpConverterTest2 : public ParameterizedOpConverterTestBase {}; +INSTANTIATE_TEST_CASE_P( + OpConvTestInstantiation, OpConverterTest2, + ::testing::Combine(::testing::ValuesIn(ValidTrtModes), + ::testing::Values(DT_FLOAT, DT_HALF), + ::testing::Values(TrtPrecisionMode::FP32))); template void CopyTensorElements(const Tensor& tensor, protobuf::RepeatedField* out) { out->Clear(); @@ -1884,14 +2006,7 @@ TEST_F(OpConverterTest, ConvertConst) { TestConvertConst(this); } -TEST_P(ParameterizedOpConverterTest, ConvertTranspose) { - const auto& spec = GetParam(); - const TrtTestMode trt_mode = std::get<0>(spec); - // Data type of TF input tensors - const DataType tf_dtype = std::get<1>(spec); - // Precision mode used for TensorRT engine - TrtPrecisionMode converter_precision = std::get<2>(spec); - +TEST_P(OpConverterTest1, ConvertTranspose) { // Get the NodeDef for Transpose. Scope s = Scope::NewRootScope(); auto input = ops::Placeholder(s.WithOpName("input"), tf_dtype); @@ -1902,7 +2017,7 @@ TEST_P(ParameterizedOpConverterTest, ConvertTranspose) { std::vector test_params = { // For the first test we leave param empty. 
This signals to use a // input as weight which will be invalid - TestParamBase{{1, 1, 2, 3}, + TestParamBase{{3, 1, 2, 1}, {}, {}, {}, @@ -1936,20 +2051,17 @@ TEST_P(ParameterizedOpConverterTest, ConvertTranspose) { std::vector expected_values{1, 4, 2, 5, 3, 6}; for (auto p : test_params) { SCOPED_TRACE(p); - Reset(converter_precision, trt_mode); - AddTestTensor("input", p.input_dims, TfDataTypeToTrt(tf_dtype), trt_mode, - &p.partial_input_dims); + Reset(); + AddTestTensor("input", p.input_dims, {1, 2, 3, 4, 5, 6}, + p.partial_input_dims); if (p.param.empty()) { AddTestTensor("weights", {3}); } else { AddTestWeights("weights", {static_cast(p.param.size())}, p.param); } - RunValidationAndConversion(node_def, p.status, "my_transpose", - p.expected_output_dims); - InstantiateBuildAndRun(tf_dtype, "my_transpose", this, p, - {1, 2, 3, 4, 5, 6}, - ElementsAreArray(expected_values)); + TestOpConverter("my_transpose", node_def, p.expected_output_dims, p.status, + p.runtime_status, ElementsAreArray(expected_values)); } } @@ -2046,7 +2158,7 @@ TEST_F(OpConverterTest, ConvertReshape) { const DataVec input_data{{"input", AsTensor(input_vec)}}; DataVec output_data{ {"my_reshape", ConstructTensor(input_vec.size())}}; - BuildAndRun(input_data, &output_data, batch_size); + TF_EXPECT_OK(BuildAndRun(input_data, &output_data, batch_size)); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(input_vec)); } @@ -2101,7 +2213,7 @@ void TestMatMulHelper( const DataVec input_data{{"input", test->AsTensor({0, 1})}}; DataVec output_data{{"my_matmul", test->ConstructTensor(2)}}; - test->BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); if (transpose_b) { EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(1, 3)); } else { @@ -2128,7 +2240,7 @@ void TestMatMulHelper( ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions()); const DataVec input_data{{"input", test->AsTensor({0, 1})}}; DataVec output_data{{"my_matmul", test->ConstructTensor(2)}}; - test->BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); if (transpose_b) { EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(1, 3)); } else { @@ -2271,7 +2383,7 @@ TEST_F(OpConverterTest, ConvertBatchMatMul) { ExpectTrtDimsEqualsArray({2, 2}, output.tensor()->getDimensions()); const DataVec input_data{{"input", AsTensor({0, 1, 2, 3})}}; DataVec output_data{{"my_matmul", ConstructTensor(4)}}; - BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(BuildAndRun(input_data, &output_data)); if (!transpose_a && !transpose_b) { EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(3, 4, 11, 16)); @@ -2291,91 +2403,70 @@ TEST_F(OpConverterTest, ConvertBatchMatMul) { TestMatMulHelper(this, get_batch_matmul_nodedef, "BatchMatMul"); } -template -void TestConvertBiasAdd(OpConverterTest* test) { +TEST_P(OpConverterTest2, ConvertBiasAdd) { + // Note that kINT32 is not supported by IScaleLayer, so we don't test + // DT_INT32 type here. DT_FLOAT and DT_HALF are tested. // Get the NodeDef for BiasAdd. 
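Before the BiasAdd changes, a quick sanity check of the Transpose expectations above: expected_values {1, 4, 2, 5, 3, 6} come from permuting a {1, 1, 2, 3} input filled with 1..6. The full permutation vector is not visible in this hunk, so the standalone sketch below assumes a perm of {0, 3, 1, 2}, which reproduces exactly that output using plain std::vector arithmetic (no TensorRT code involved).

#include <cassert>
#include <vector>

// Permute a row-major 4-D tensor: output dim i corresponds to input dim perm[i].
std::vector<float> Transpose4D(const std::vector<float>& in,
                               const std::vector<int>& dims,
                               const std::vector<int>& perm) {
  std::vector<int> out_dims(4);
  for (int i = 0; i < 4; ++i) out_dims[i] = dims[perm[i]];
  // Row-major strides of the input tensor.
  std::vector<int> stride(4, 1);
  for (int i = 2; i >= 0; --i) stride[i] = stride[i + 1] * dims[i + 1];
  std::vector<float> out(in.size());
  int out_idx = 0;
  for (int a = 0; a < out_dims[0]; ++a)
    for (int b = 0; b < out_dims[1]; ++b)
      for (int c = 0; c < out_dims[2]; ++c)
        for (int d = 0; d < out_dims[3]; ++d) {
          const int out_coord[4] = {a, b, c, d};
          int in_idx = 0;
          // Output coordinate i indexes input dimension perm[i].
          for (int i = 0; i < 4; ++i) in_idx += out_coord[i] * stride[perm[i]];
          out[out_idx++] = in[in_idx];
        }
  return out;
}

int main() {
  const std::vector<float> in = {1, 2, 3, 4, 5, 6};  // shape {1, 1, 2, 3}
  const std::vector<float> out = Transpose4D(in, {1, 1, 2, 3}, {0, 3, 1, 2});
  assert((out == std::vector<float>{1, 4, 2, 5, 3, 6}));  // matches the test
  return 0;
}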
- auto get_biasadd_nodedef = [](const string& data_format) -> NodeDef { + auto get_biasadd_nodedef = [](const string& data_format, + DataType tf_dtype) -> NodeDef { Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), dtype); - auto weights = ops::Placeholder(s.WithOpName("weights"), dtype); + auto input = ops::Placeholder(s.WithOpName("input"), tf_dtype); + auto weights = ops::Placeholder(s.WithOpName("weights"), tf_dtype); const auto biasadd_attrs = ops::BiasAdd::DataFormat(data_format); auto biasadd = ops::BiasAdd(s.WithOpName("my_biasadd"), input, weights, biasadd_attrs); return biasadd.operation.node()->def(); }; - typedef typename EnumToDataType::Type CType; for (const string& data_format : {"NHWC", "NCHW"}) { for (const int trt_input_rank : {1, 2, 3, 4}) { - test->Reset(); - NodeDef node_def = get_biasadd_nodedef(data_format); + Reset(); + NodeDef node_def = get_biasadd_nodedef(data_format, tf_dtype); // Add input, dims_array will be like {2, 1, ..., 1, 3} - std::vector dims_array(trt_input_rank, 1); + std::vector dims_array(trt_input_rank + 1, 1); if (trt_input_rank == 1) { - dims_array[0] = (data_format == "NHWC" ? 3 : 2); + dims_array[1] = (data_format == "NHWC" ? 3 : 2); } else { - dims_array[0] = 2; - dims_array[trt_input_rank - 1] = 3; + dims_array[1] = 2; + dims_array[trt_input_rank] = 3; } - test->AddTestTensor("input", dims_array, /*batch_size=*/1, - TfDataTypeToTrt(dtype)); - - // Add bias weights. - const int channel_size = (data_format == "NHWC" ? 3 : 2); - std::vector bias(channel_size); - for (int i = 0; i < channel_size; ++i) { - bias[i] = CType(i + 1); // bias will be {1, 2, 3, ...} - } - test->AddTestWeights("weights", {channel_size}, bias); - - // Run the conversion. - test->RunValidationAndConversion(node_def); - TRT_TensorOrWeights output; - TF_EXPECT_OK(test->GetTensorOrWeights("my_biasadd", &output)); - ASSERT_TRUE(output.is_tensor()); - ExpectTrtDimsEqualsArray(dims_array, output.tensor()->getDimensions()); - - // Build and run the engine. const int num_input = TrtTensorDimsNumElements(GetTestDims(dims_array)); ASSERT_EQ(trt_input_rank > 1 ? 6 : (data_format == "NHWC" ? 3 : 2), num_input); + std::vector input_data(num_input, 0); + + AddTestTensor("input", dims_array, input_data); + + const int channel_size = (data_format == "NHWC" ? 3 : 2); + std::vector bias(channel_size); + for (int i = 0; i < channel_size; ++i) { + bias[i] = i + 1; // bias will be {1, 2, 3, ...} + } + AddTestWeights("weights", {channel_size}, bias, tf_dtype); + + // Build and run the engine. 
+ std::vector output_data; - const DataVec input_data{ - {"input", test->ConstructTensor(num_input, CType(0))}}; - DataVec output_data{ - {"my_biasadd", test->ConstructTensor(num_input)}}; - test->BuildAndRun(input_data, &output_data); if (trt_input_rank == 1) { if (data_format == "NHWC") { - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(1), CType(2), CType(3))); + output_data = {1, 2, 3}; } else { - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(1), CType(2))); + output_data = {1, 2}; } } else { if (data_format == "NHWC") { - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(1), CType(2), CType(3), CType(1), - CType(2), CType(3))); + output_data = {1, 2, 3, 1, 2, 3}; } else { - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(CType(1), CType(1), CType(1), CType(2), - CType(2), CType(2))); + output_data = {1, 1, 1, 2, 2, 2}; } } + TestOpConverter("my_biasadd", node_def, dims_array, Status::OK(), + Status::OK(), ElementsAreArray(output_data)); } } } -TEST_F(OpConverterTest, ConvertBiasAdd) { - // OK. Note that kINT32 is not supported by IScaleLayer, so we don't test - // DT_INT32 type here. - TestConvertBiasAdd(this); - TestConvertBiasAdd(this); -} - template NodeDef GetBinaryOpNodeDef(const string& input_name_l, const string& input_name_r, DataType dtype) { @@ -2428,7 +2519,7 @@ void TestBinaryOp(OpConverterTest* test, bool operand_1_is_tensor, ExpectTrtDimsEqualsArray({2, 2}, output.tensor()->getDimensions()); // After broadcasting first input becomes {3, 6, 3, 6} and second input // becomes {2, 3, 2, 3}. - test->BuildAndRun(input_data, &output_data, /*batch_size=*/2); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data, /*batch_size=*/2)); if (node_def.op() == "Add") { EXPECT_THAT( GetSpanForData(output_data[0]), @@ -2562,7 +2653,7 @@ void TestAddN(OpConverterTest* test) { ExpectTrtDimsEqualsArray({1, 2}, output.tensor()->getDimensions()); DataVec output_data{{"my_addn", test->ConstructTensor(4)}}; - test->BuildAndRun(input_data, &output_data, /*batch_size=*/2); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data, /*batch_size=*/2)); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(CastTestVector({3, 6, 9, 12}))); } @@ -2586,7 +2677,7 @@ void TestAddN(OpConverterTest* test) { ExpectTrtDimsEqualsArray({1, 2}, output.tensor()->getDimensions()); DataVec output_data{{"my_addn", test->ConstructTensor(2)}}; - test->BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(CastTestVector({5, 8}))); } @@ -2752,7 +2843,7 @@ void TestConvertSquare(OpConverterTest* test) { // Engine outputs are converted to FP16 automatically if we set FP16 mode in // the builder. 
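Returning to the BiasAdd expectations above: since the test input is all zeros, the output is simply the bias broadcast along the channel axis, which is the last dimension for NHWC and the second for NCHW. The standalone check below reproduces the rank > 1 case with dims {1, 2, 1, 3}; it is illustrative only and uses no TensorFlow or TensorRT code.

#include <cassert>
#include <vector>

// Broadcast-add a per-channel bias to an all-zero tensor of shape {1, 2, 1, 3}
// and return the flattened (row-major) result.
std::vector<float> BiasAddZeros(const std::vector<float>& bias, bool nhwc) {
  const int dims[4] = {1, 2, 1, 3};
  const int channel_axis = nhwc ? 3 : 1;
  std::vector<float> out;
  for (int n = 0; n < dims[0]; ++n)
    for (int i = 0; i < dims[1]; ++i)
      for (int j = 0; j < dims[2]; ++j)
        for (int k = 0; k < dims[3]; ++k) {
          const int coord[4] = {n, i, j, k};
          out.push_back(0.0f + bias[coord[channel_axis]]);
        }
  return out;
}

int main() {
  // NHWC: channel size 3, bias {1, 2, 3} repeats along the innermost axis.
  assert((BiasAddZeros({1, 2, 3}, true) ==
          std::vector<float>{1, 2, 3, 1, 2, 3}));
  // NCHW: channel size 2, bias {1, 2} repeats once per channel plane.
  assert((BiasAddZeros({1, 2}, false) ==
          std::vector<float>{1, 1, 1, 2, 2, 2}));
  return 0;
}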
DataVec output_data{{"my_square", test->ConstructTensor(num_inputs)}}; - test->BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); ExpectArrayNear(expected_outputs, GetSpanForData(output_data[0])); } @@ -2864,7 +2955,7 @@ TEST_F(OpConverterTest, ConvertCombinedNMS) { }; const DataVec input_data{{"boxes", AsTensor({0, 0, 0.3, 0.4})}, {"scores", AsTensor({0.4, 0.7, 0.3})}}; - BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(0, 0, 0.3, 0.4, 0, 0, 0.3, 0.4)); EXPECT_THAT(GetSpanForData(output_data[1]), ElementsAre(0.7, 0.4)); @@ -2874,90 +2965,67 @@ TEST_F(OpConverterTest, ConvertCombinedNMS) { } #endif // IS_TRT_VERSION_GE(5, 1, 0, 0) -TEST_F(OpConverterTest, ConvertActivation) { +template +NodeDef CreateUnaryOp(DataType tf_dtype) { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), tf_dtype); + return T(s.WithOpName("my_unary"), input).operation.node()->def(); +} + +constexpr float kLeakyReluAlpha = 0.2f; +template <> +NodeDef CreateUnaryOp(DataType tf_dtype) { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), tf_dtype); + return ops::internal::LeakyRelu( + s.WithOpName("my_unary"), input, + ops::internal::LeakyRelu::Alpha(kLeakyReluAlpha)) + .operation.node() + ->def(); +} + +TEST_P(OpConverterTest1, ConvertActivation) { { // Input is weights, should fail. Reset(); - Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); - auto relu = ops::Relu(s.WithOpName("my_act"), input); - const NodeDef& node_def = relu.operation.node()->def(); + const NodeDef& node_def = CreateUnaryOp(tf_dtype); AddTestWeights("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, - "The input \"input\" for Relu must be a tensor, at my_act"); + "The input \"input\" for Relu must be a tensor, at my_unary"); } - constexpr float kLeakyReluAlpha = 0.2f; constexpr float kSeluAlpha = 1.7580993408473768599402175208123f; constexpr float kSeluScale = 1.0507009873554804934193349852946f; + using OpFunc = std::function; + using ValFunc = float (*)(float); + std::map> op_map; - // Get nodedef for activation layer. 
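The CreateUnaryOp template and its LeakyRelu specialization above replace the removed per-op builder chain with a generic factory plus one explicit specialization for the op that needs an extra attribute. The toy sketch below isolates that shape of code; FakeNodeDef, FakeRelu and FakeLeaky are invented stand-ins for the real Scope/NodeDef API, used purely to show the specialization pattern.

#include <iostream>
#include <string>

// Toy stand-in for a node description (the real code builds a tf NodeDef).
struct FakeNodeDef {
  std::string op;
  float alpha = 0.0f;  // only meaningful for ops carrying an extra attribute
};

// Toy op tags standing in for ops::Relu, ops::internal::LeakyRelu, ...
struct FakeRelu  { static const char* Name() { return "Relu"; } };
struct FakeLeaky { static const char* Name() { return "LeakyRelu"; } };

// Generic factory: works for every op that needs no extra attributes.
template <typename OpT>
FakeNodeDef CreateUnaryOp() {
  return FakeNodeDef{OpT::Name()};
}

// Full specialization for the one op that takes an attribute, mirroring the
// LeakyRelu specialization in the diff.
constexpr float kLeakyReluAlpha = 0.2f;
template <>
FakeNodeDef CreateUnaryOp<FakeLeaky>() {
  return FakeNodeDef{FakeLeaky::Name(), kLeakyReluAlpha};
}

int main() {
  std::cout << CreateUnaryOp<FakeRelu>().op << "\n";      // "Relu"
  std::cout << CreateUnaryOp<FakeLeaky>().alpha << "\n";  // 0.2
  return 0;
}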
- auto get_act_nodedef = [](string op_name) -> NodeDef { - Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); - if (op_name == "LeakyRelu") { - auto act = ops::internal::LeakyRelu( - s.WithOpName("my_act"), input, - ops::internal::LeakyRelu::Alpha(kLeakyReluAlpha)); - return act.operation.node()->def(); - } else if (op_name == "Relu") { - auto act = ops::Relu(s.WithOpName("my_act"), input); - return act.operation.node()->def(); - } else if (op_name == "Relu6") { - auto act = ops::Relu6(s.WithOpName("my_act"), input); - return act.operation.node()->def(); - } else if (op_name == "Sigmoid") { - auto act = ops::Sigmoid(s.WithOpName("my_act"), input); - return act.operation.node()->def(); - } else if (op_name == "Tanh") { - auto act = ops::Tanh(s.WithOpName("my_act"), input); - return act.operation.node()->def(); - } else if (op_name == "Elu") { - auto act = ops::Elu(s.WithOpName("my_act"), input); - return act.operation.node()->def(); - } else if (op_name == "Selu") { - auto act = ops::Selu(s.WithOpName("my_act"), input); - return act.operation.node()->def(); - } else if (op_name == "Softsign") { - auto act = ops::Softsign(s.WithOpName("my_act"), input); - return act.operation.node()->def(); - } else if (op_name == "Softplus") { - auto act = ops::Softplus(s.WithOpName("my_act"), input); - return act.operation.node()->def(); - } - EXPECT_TRUE(false); - return NodeDef(); - }; - // Get expected output for activation layer. - auto get_act_output = [](string op_name, float input) -> float { - if (op_name == "LeakyRelu") { - return (input > 0.0f) ? input : input * kLeakyReluAlpha; - } else if (op_name == "Relu") { - return (input > 0.0f) ? input : 0.0f; - } else if (op_name == "Relu6") { - return std::min(std::max(input, 0.0f), 6.0f); - } else if (op_name == "Sigmoid") { - return 1.0f / (1.0f + std::exp(-input)); - } else if (op_name == "Tanh") { - return std::tanh(input); - } else if (op_name == "Elu") { - return (input > 0.0f) ? input : std::exp(input) - 1; - } else if (op_name == "Selu") { - return (input > 0.0f) ? kSeluScale * input - : kSeluScale * kSeluAlpha * (std::exp(input) - 1); - } else if (op_name == "Softsign") { - return input / (std::abs(input) + 1); - } else if (op_name == "Softplus") { - return std::log(std::exp(input) + 1); - } - EXPECT_TRUE(false); - return 0; - }; +#define ADD_OP(name, op, compute) \ + op_map[name] = std::make_pair(CreateUnaryOp, compute) + ADD_OP("LeakyRelu", ops::internal::LeakyRelu, + [](float x) { return (x > 0.0f) ? x : x * kLeakyReluAlpha; }); + ADD_OP("Relu", ops::Relu, [](float x) { return (x > 0.0f) ? x : 0.0f; }); + ADD_OP("Relu6", ops::Relu6, + [](float x) { return std::min(std::max(x, 0.0f), 6.0f); }); + ADD_OP("Sigmoid", ops::Sigmoid, + [](float x) { return 1.0f / (1.0f + std::exp(-x)); }); + ADD_OP("Tanh", ops::Tanh, static_cast(std::tanh)); + ADD_OP("Elu", ops::Elu, + [](float x) { return (x > 0.0f) ? x : std::exp(x) - 1; }); + ADD_OP("Selu", ops::Selu, [](float x) { + return (x > 0.0f) ? kSeluScale * x + : kSeluScale * kSeluAlpha * (std::exp(x) - 1); + }); + ADD_OP("Softsign", ops::Softsign, + [](float x) { return x / (std::abs(x) + 1); }); + ADD_OP("Softplus", ops::Softplus, + [](float x) { return std::log(std::exp(x) + 1); }); +#undef ADD_OP // Get list of ops to test. std::vector ops_to_test; - // Add all ops supported by ConvertUnary. + // Add all ops supported by ConvertActivation. 
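The ADD_OP table above pairs each op name with a NodeDef factory and a plain float-to-float reference implementation, so expected outputs are derived with std::transform instead of a second hand-written switch. A minimal standalone version of that pattern, with no TF or TRT types and only a few representative activations:

#include <algorithm>
#include <cassert>
#include <cmath>
#include <functional>
#include <map>
#include <string>
#include <vector>

int main() {
  using ValFunc = std::function<float(float)>;
  std::map<std::string, ValFunc> reference;
  reference["Relu"]    = [](float x) { return x > 0.0f ? x : 0.0f; };
  reference["Relu6"]   = [](float x) { return std::min(std::max(x, 0.0f), 6.0f); };
  reference["Sigmoid"] = [](float x) { return 1.0f / (1.0f + std::exp(-x)); };

  // One shared input vector; expected outputs are computed, not typed in.
  const std::vector<float> input = {-100.0f, -2.0f, -1.0f, 0.0f, 1.0f, 88.0f};
  for (const auto& entry : reference) {
    std::vector<float> expected;
    std::transform(input.begin(), input.end(), std::back_inserter(expected),
                   entry.second);
    assert(expected.size() == input.size());
    // In the real test, `expected` feeds TestOpConverter(..., ArrayFloatNear(...))
    // rather than being asserted locally like this.
  }
  return 0;
}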
auto* map = ActivationTypeMap(); ops_to_test.reserve(map->size()); for (auto& pair : *map) { @@ -2966,16 +3034,30 @@ TEST_F(OpConverterTest, ConvertActivation) { // Add other activation ops to test. ops_to_test.push_back("Relu6"); ops_to_test.push_back("LeakyRelu"); + auto p = TestParamBase{ + {1, 1, 2, 3}, // input dims + {}, // input partial dims + {1, 1, 2, 3}, // expected output dims + }; // Ok. for (const string& op_name : ops_to_test) { + if (!op_map.count(op_name)) { + FAIL() << "Activation op test map does not contain op " << op_name; + } Reset(); - NodeDef node_def = get_act_nodedef(op_name); - AddTestTensor("input", {1, 2, 3}); - RunValidationAndConversion(node_def); + NodeDef node_def = op_map[op_name].first(tf_dtype); + const std::vector input = {-100, -2, -1, 0, 1, 88}; + AddTestTensor("input", p.input_dims, input); + + // std::exp in Softplus will overflow for input > 88 + std::vector output_values; + std::transform(input.begin(), input.end(), + std::back_inserter(output_values), op_map[op_name].second); + TestOpConverter("my_unary", node_def, p.expected_output_dims, Status::OK(), + Status::OK(), ArrayFloatNear(output_values, 0, false)); + TRT_TensorOrWeights output; - TF_EXPECT_OK(GetTensorOrWeights("my_act", &output)); - ASSERT_TRUE(output.is_tensor()); - ExpectTrtDimsEqualsArray({1, 2, 3}, output.tensor()->getDimensions()); + TF_EXPECT_OK(GetTensorOrWeights("my_unary", &output)); // Certain activations should set quantization range automatically. auto ranges = quantization_ranges(); @@ -2985,17 +3067,6 @@ TEST_F(OpConverterTest, ConvertActivation) { op_name == "Softsign") { EXPECT_EQ(ranges[output.tensor()], 1.0f); } - - // std::exp in Softplus will overflow for input > 88 - const std::vector input = {-100, -2, -1, 0, 1, 88}; - const DataVec input_data{{"input", AsTensor(input)}}; - DataVec output_data{{"my_act", ConstructTensor(6)}}; - BuildAndRun(input_data, &output_data); - for (int i = 0; i < input.size(); i++) { - const float expected_output = get_act_output(op_name, input[i]); - EXPECT_FLOAT_EQ(GetSpanForData(output_data[0])[i], - expected_output); - } } } @@ -3095,23 +3166,17 @@ TEST_F(OpConverterTest, ConvertExpandDims) { const DataVec input_data{{"input", AsTensor({1, 2, 3, 4, 5, 6})}}; DataVec output_data{{"my_expanddims", ConstructTensor(6)}}; - BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(1, 2, 3, 4, 5, 6)); } } -TEST_P(ParameterizedOpConverterTest, ConvertSqueeze) { - const auto& spec = GetParam(); - const TrtTestMode trt_mode = std::get<0>(spec); +TEST_P(OpConverterTest1, ConvertSqueeze) { const bool use_implicit_batch = (trt_mode == TrtTestMode::kImplicitBatch); - // Data type of TF input tensors - const DataType tf_dtype = std::get<1>(spec); - // Precision mode used for TensorRT engine - TrtPrecisionMode converter_precision = std::get<2>(spec); - // Get the NodeDef for Squeeze. 
- auto get_squeeze_nodedef = [tf_dtype](std::vector axes) -> NodeDef { + auto get_squeeze_nodedef = [](std::vector axes, + DataType tf_dtype) -> NodeDef { Scope s = Scope::NewRootScope(); auto input = ops::Placeholder(s.WithOpName("input"), tf_dtype); if (!axes.empty()) { @@ -3129,11 +3194,13 @@ TEST_P(ParameterizedOpConverterTest, ConvertSqueeze) { TestParamBase{ {1, 2, 1, 3}, // input dims {}, // input partial dims - {2, 1, 3}, // expected output dims + {2, 3}, // expected output dims {}, // axis - Status{ - error::UNIMPLEMENTED, - "Squeeze is only implemented for explicit dims, at my_squeeze"}}, + trt_mode == TrtTestMode::kExplicitBatch + ? Status::OK() + : Status{error::UNIMPLEMENTED, + "Squeeze is not implemented for empty squeeze_dims, at " + "my_squeeze"}}, TestParamBase{{1, 2, 1, 3}, {}, {2, 1, 3}, @@ -3202,14 +3269,12 @@ TEST_P(ParameterizedOpConverterTest, ConvertSqueeze) { for (TestParamBase p : test_params) { SCOPED_TRACE(p); - Reset(converter_precision, trt_mode); - NodeDef node_def = get_squeeze_nodedef(p.param); - AddTestTensor("input", p.input_dims, TfDataTypeToTrt(tf_dtype), trt_mode, - &p.partial_input_dims); - RunValidationAndConversion(node_def, p.status, "my_squeeze", - p.expected_output_dims); - InstantiateBuildAndRun(tf_dtype, "my_squeeze", this, p, {1, 2, 3, 4, 5, 6}, - ElementsAreArray({1, 2, 3, 4, 5, 6})); + Reset(); + NodeDef node_def = get_squeeze_nodedef(p.param, tf_dtype); + AddTestTensor("input", p.input_dims, {1, 2, 3, 4, 5, 6}, + p.partial_input_dims); + TestOpConverter("my_squeeze", node_def, p.expected_output_dims, p.status, + p.runtime_status, ElementsAreArray({1, 2, 3, 4, 5, 6})); } } @@ -3812,7 +3877,7 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { DataVec output_data{ {"my_strided_slice", ConstructTensor(ok_params[i].expected_output.size())}}; - BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(ok_params[i].expected_output)); } @@ -3952,21 +4017,22 @@ TEST_F(OpConverterTest, ConvertSlice) { const DataVec input_data{{"input", AsTensor({1, 2, 3, 4, 5, 6})}}; DataVec output_data{{"my_slice", ConstructTensor( ok_params[i].expected_output.size())}}; - BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(ok_params[i].expected_output)); } } -TEST_F(OpConverterTest, ConvertConv2D) { +TEST_P(OpConverterTest1, ConvertConv2D) { // Get nodedef for Conv2D layer. + DataType tf_type = tf_dtype; auto get_conv2d_nodedef = - [](std::vector strides = {1, 1, 1, 1}, string padding = "SAME", - string data_format = "NCHW", - std::vector dilations = {1, 1, 1, 1}) -> NodeDef { + [tf_type](std::vector strides = {1, 1, 1, 1}, + string padding = "SAME", string data_format = "NCHW", + std::vector dilations = {1, 1, 1, 1}) -> NodeDef { Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); - auto filter = ops::Placeholder(s.WithOpName("weights"), DT_FLOAT); + auto input = ops::Placeholder(s.WithOpName("input"), tf_type); + auto filter = ops::Placeholder(s.WithOpName("weights"), tf_type); ops::Conv2D::Attrs attrs = ops::Conv2D::Attrs().DataFormat(data_format).Dilations(dilations); auto conv2d = ops::Conv2D(s.WithOpName("my_conv2d"), input, filter, strides, @@ -3988,7 +4054,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { // Filter is tensor, should fail. 
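An aside on the Squeeze expectations above: with an empty squeeze_dims, TensorFlow removes every size-1 dimension, so {1, 2, 1, 3} becomes {2, 3} rather than {2, 1, 3}; in implicit batch mode the batch dimension cannot be touched, hence the UNIMPLEMENTED status instead of an output shape. The shape rule in isolation (plain C++, no TRT):

#include <cassert>
#include <vector>

// Squeeze semantics for an empty squeeze_dims: drop every dimension of size 1.
std::vector<int> SqueezeAll(const std::vector<int>& dims) {
  std::vector<int> out;
  for (int d : dims)
    if (d != 1) out.push_back(d);
  return out;
}

int main() {
  assert((SqueezeAll({1, 2, 1, 3}) == std::vector<int>{2, 3}));
  return 0;
}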
Reset(); NodeDef node_def = get_conv2d_nodedef(); - AddTestTensor("input", {1, 2, 3}); + AddTestTensor("input", {3, 1, 2, 1}); AddTestTensor("weights", {3, 3, 1, 1}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, @@ -3998,7 +4064,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { // Filter is not 4D, should fail. Reset(); NodeDef node_def = get_conv2d_nodedef(); - AddTestTensor("input", {1, 2, 3}); + AddTestTensor("input", {1, 1, 2, 3}); AddTestWeights("weights", {3, 3, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, @@ -4009,7 +4075,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { Reset(); NodeDef node_def = get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NCHW", {1, 1, 1}); - AddTestTensor("input", {1, 2, 3}); + AddTestTensor("input", {1, 1, 2, 3}); AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, @@ -4020,7 +4086,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { Reset(); NodeDef node_def = get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NCHW", {1, 2, 1, 1}); - AddTestTensor("input", {1, 2, 3}); + AddTestTensor("input", {1, 1, 2, 3}); AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); RunValidationAndConversion(node_def, error::UNIMPLEMENTED, "Dilation rate must be 1 for batch and channel " @@ -4031,7 +4097,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { Reset(); NodeDef node_def = get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NHWC", {1, 1, 1, 2}); - AddTestTensor("input", {2, 3, 1}); + AddTestTensor("input", {1, 2, 3, 1}); AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); RunValidationAndConversion(node_def, error::UNIMPLEMENTED, "Dilation rate must be 1 for batch and channel " @@ -4042,7 +4108,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { Reset(); NodeDef node_def = get_conv2d_nodedef({1, 1, 1}, "SAME", "NCHW", {1, 1, 1, 1}); - AddTestTensor("input", {1, 2, 3}); + AddTestTensor("input", {1, 1, 2, 3}); AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, @@ -4053,12 +4119,23 @@ TEST_F(OpConverterTest, ConvertConv2D) { Reset(); NodeDef node_def = get_conv2d_nodedef({1, 2, 1, 1}, "SAME", "NCHW", {1, 1, 1, 1}); - AddTestTensor("input", {1, 2, 3}); + AddTestTensor("input", {1, 1, 2, 3}); AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, "Stride must be 1 for batch and channel dimensions, at my_conv2d"); } + if (trt_mode == TrtTestMode::kDynamicShape) { + Reset(); + NodeDef node_def = get_conv2d_nodedef(); + // Channel dim unknown, should fail. + AddTestTensorWithTFDims("input", {-1, -1, -1, -1}, + TfDataTypeToTrt(tf_type)); + AddTestWeights("weights", {1, 2, 1, 1}, {-1, 1}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Channel dimension must be static, at my_conv2d"); + } struct TestParams { std::vector input_dims; @@ -4076,7 +4153,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { // Ok. 
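The first "Basic" Conv2D case that follows is easy to verify by hand: the {1, 1, 2, 3} input holds the rows [0, 1, 2] and [3, 3, 4], the 1x2 filter is [-1, 1], and VALID padding slides the filter along each row, giving {1, 1, 0, 1}. A standalone check of that arithmetic (single channel, stride 1, no dilation; illustrative only):

#include <cassert>
#include <vector>

// 1x2 VALID convolution over a single-channel H x W input, stride 1.
std::vector<float> Conv1x2Valid(const std::vector<float>& in, int h, int w,
                                float f0, float f1) {
  std::vector<float> out;
  for (int r = 0; r < h; ++r)
    for (int c = 0; c + 1 < w; ++c)
      out.push_back(f0 * in[r * w + c] + f1 * in[r * w + c + 1]);
  return out;
}

int main() {
  // Input {1, 1, 2, 3} flattened: rows [0, 1, 2] and [3, 3, 4]; filter [-1, 1].
  const std::vector<float> out = Conv1x2Valid({0, 1, 2, 3, 3, 4}, 2, 3, -1.0f, 1.0f);
  assert((out == std::vector<float>{1, 1, 0, 1}));  // matches expected_output
  return 0;
}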
std::vector ok_params = { // Basic - TestParams{/*input_dims=*/{1, 2, 3}, + TestParams{/*input_dims=*/{1, 1, 2, 3}, /*input=*/{0, 1, 2, 3, 3, 4}, /*filter_dims=*/{1, 2, 1, 1}, /*filter=*/{-1, 1}, @@ -4084,10 +4161,10 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"VALID", /*data_format=*/"NCHW", /*dilations=*/{1, 1, 1, 1}, - /*expected_output_dims=*/{1, 2, 2}, + /*expected_output_dims=*/{1, 1, 2, 2}, /*expected_output=*/{1, 1, 0, 1}}, // SAME padding (Asymmetric) - TestParams{/*input_dims=*/{1, 2, 3}, + TestParams{/*input_dims=*/{1, 1, 2, 3}, /*input=*/{0, 1, 2, 3, 3, 4}, /*filter_dims=*/{1, 2, 1, 1}, /*filter=*/{-1, 1}, @@ -4095,10 +4172,10 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"SAME", /*data_format=*/"NCHW", /*dilations=*/{1, 1, 1, 1}, - /*expected_output_dims=*/{1, 2, 3}, + /*expected_output_dims=*/{1, 1, 2, 3}, /*expected_output=*/{1, 1, -2, 0, 1, -4}}, // SAME padding (Symmetric) - TestParams{/*input_dims=*/{1, 2, 3}, + TestParams{/*input_dims=*/{1, 1, 2, 3}, /*input=*/{0, 1, 2, 3, 3, 4}, /*filter_dims=*/{1, 3, 1, 1}, /*filter=*/{-1, 0, 1}, @@ -4106,10 +4183,10 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"SAME", /*data_format=*/"NCHW", /*dilations=*/{1, 1, 1, 1}, - /*expected_output_dims=*/{1, 2, 3}, + /*expected_output_dims=*/{1, 1, 2, 3}, /*expected_output=*/{1, 2, -1, 3, 1, -3}}, // NHWC - TestParams{/*input_dims=*/{2, 3, 1}, + TestParams{/*input_dims=*/{1, 2, 3, 1}, /*input=*/{0, 1, 2, 3, 3, 4}, /*filter_dims=*/{1, 2, 1, 1}, /*filter=*/{-1, 1}, @@ -4117,10 +4194,10 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"VALID", /*data_format=*/"NHWC", /*dilations=*/{1, 1, 1, 1}, - /*expected_output_dims=*/{2, 2, 1}, + /*expected_output_dims=*/{1, 2, 2, 1}, /*expected_output=*/{1, 1, 0, 1}}, // Dilated - TestParams{/*input_dims=*/{1, 2, 3}, + TestParams{/*input_dims=*/{1, 1, 2, 3}, /*input=*/{0, 1, 2, 3, 3, 4}, /*filter_dims=*/{1, 2, 1, 1}, /*filter=*/{-1, 1}, @@ -4128,10 +4205,10 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"VALID", /*data_format=*/"NCHW", /*dilations=*/{1, 1, 1, 2}, - /*expected_output_dims=*/{1, 2, 1}, + /*expected_output_dims=*/{1, 1, 2, 1}, /*expected_output=*/{2, 1}}, // Strided - TestParams{/*input_dims=*/{1, 2, 4}, + TestParams{/*input_dims=*/{1, 1, 2, 4}, /*input=*/{0, 1, 2, 2, 3, 4, 4, 7}, /*filter_dims=*/{1, 2, 1, 1}, /*filter=*/{-1, 1}, @@ -4139,7 +4216,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"VALID", /*data_format=*/"NCHW", /*dilations=*/{1, 1, 1, 1}, - /*expected_output_dims=*/{1, 2, 2}, + /*expected_output_dims=*/{1, 1, 2, 2}, /*expected_output=*/{1, 0, 1, 3}}, }; @@ -4148,23 +4225,22 @@ TEST_F(OpConverterTest, ConvertConv2D) { NodeDef node_def = get_conv2d_nodedef(ok_params[i].strides, ok_params[i].padding, ok_params[i].data_format, ok_params[i].dilations); - AddTestTensor("input", ok_params[i].input_dims); + std::vector partial_input_shape; + if (trt_mode == TrtTestMode::kDynamicShape) { + // The channel dim cannot have unknown size, fix that. + partial_input_shape.resize(ok_params[i].input_dims.size(), -1); + int channel_id = (ok_params[i].data_format == "NCHW") ? 
1 : 3; + partial_input_shape[channel_id] = ok_params[i].input_dims[channel_id]; + } + + AddTestTensor("input", ok_params[i].input_dims, tf_dtype, + ok_params[i].input, partial_input_shape); AddTestWeights("weights", ok_params[i].filter_dims, ok_params[i].filter); - RunValidationAndConversion(node_def); - TRT_TensorOrWeights output; - TF_EXPECT_OK(GetTensorOrWeights("my_conv2d", &output)); - ASSERT_TRUE(output.is_tensor()); - ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, - output.tensor()->getDimensions()); - const DataVec input_data{{"input", AsTensor(ok_params[i].input)}}; - DataVec output_data{ - {"my_conv2d", - ConstructTensor(ok_params[i].expected_output.size())}}; - BuildAndRun(input_data, &output_data); - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAreArray(ok_params[i].expected_output)); + TestOpConverter("my_conv2d", node_def, ok_params[i].expected_output_dims, + Status::OK(), Status::OK(), + ElementsAreArray(ok_params[i].expected_output)); } } @@ -4289,7 +4365,7 @@ TEST_F(OpConverterTest, ConvertConv2DBackpropInput) { DataVec output_data{ {"my_conv2d_backprop_input", ConstructTensor(ok_params[i].expected_output.size())}}; - BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(ok_params[i].expected_output)); } @@ -4621,7 +4697,7 @@ TEST_F(OpConverterTest, ConvertConv3D) { DataVec output_data{ {"my_conv3d", ConstructTensor(ok_params[i].expected_output.size())}}; - BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(ok_params[i].expected_output)); } @@ -4808,7 +4884,7 @@ TEST_F(OpConverterTest, ConvertPool3D) { DataVec output_data{ {expected_node_name, ConstructTensor(ok_params[i].expected_output.size())}}; - BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(ok_params[i].expected_output)); } @@ -4853,7 +4929,7 @@ TEST_F(OpConverterTest, ConvertTopK) { {"input", AsTensor({-9, 3, 5, 1, 6, -5, 7, 1, 0, -1})}}; DataVec output_data{{"my_topk", ConstructTensor(4)}, {"my_topk:1", ConstructTensor(4)}}; - BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAre(6, 5, 7, 1)); EXPECT_THAT(GetSpanForData(output_data[1]), @@ -5040,8 +5116,8 @@ void TestConvertGather(OpConverterTest* test) { } DataVec output_data{ {"my_gather", test->ConstructTensor(expected_output.size())}}; - test->BuildAndRun(input_data, &output_data, - /*batch_size=*/expected_output_shape[0]); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data, + /*batch_size=*/expected_output_shape[0])); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(converted_expected_output)); } @@ -5112,135 +5188,52 @@ TEST_F(OpConverterTest, ConvertGather) { TestConvertGather(this); } -TEST_F(OpConverterTest, ConvertUnary) { +NodeDef CreateCastOp(DataType tf_dtype) { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_HALF); + return ops::Cast(s.WithOpName("my_unary"), input, DT_FLOAT) + .operation.node() + ->def(); +} + +TEST_P(OpConverterTest1, ConvertUnary) { { // Input is weights, should fail. 
Reset(); - Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); - auto neg = ops::Neg(s.WithOpName("my_unary"), input); - const NodeDef& node_def = neg.operation.node()->def(); + const NodeDef node_def = CreateUnaryOp(tf_dtype); AddTestWeights("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, "The input \"x\" for Neg must be a tensor, at my_unary"); } - - // Get nodedef for unary layer. - auto get_unary_nodedef = [](string op_name) -> NodeDef { - Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); - if (op_name == "Abs") { - auto unary = ops::Abs(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Acos") { - auto unary = ops::Acos(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Acosh") { - auto unary = ops::Acosh(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Asin") { - auto unary = ops::Asin(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Asinh") { - auto unary = ops::Asinh(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Atan") { - auto unary = ops::Atan(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Atanh") { - auto unary = ops::Atanh(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Ceil") { - auto unary = ops::Ceil(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Cos") { - auto unary = ops::Cos(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Cosh") { - auto unary = ops::Cosh(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Exp") { - auto unary = ops::Exp(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Floor") { - auto unary = ops::Floor(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Log") { - auto unary = ops::Log(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Neg") { - auto unary = ops::Neg(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Reciprocal") { - auto unary = ops::Reciprocal(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Rsqrt") { - auto unary = ops::Rsqrt(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Sin") { - auto unary = ops::Sin(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Sinh") { - auto unary = ops::Sinh(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Sqrt") { - auto unary = ops::Sqrt(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Tan") { - auto unary = ops::Tan(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } - EXPECT_TRUE(false); - return NodeDef(); - }; - // Get expected output for unary layer. 
- auto get_unary_output = [](string op_name, float input) -> float { - if (op_name == "Abs") { - return std::abs(input); - } else if (op_name == "Acos") { - return std::acos(input); - } else if (op_name == "Acosh") { - return std::acosh(input); - } else if (op_name == "Asin") { - return std::asin(input); - } else if (op_name == "Asinh") { - return std::asinh(input); - } else if (op_name == "Atan") { - return std::atan(input); - } else if (op_name == "Atanh") { - return std::atanh(input); - } else if (op_name == "Ceil") { - return std::ceil(input); - } else if (op_name == "Cos") { - return std::cos(input); - } else if (op_name == "Cosh") { - return std::cosh(input); - } else if (op_name == "Exp") { - return std::exp(input); - } else if (op_name == "Floor") { - return std::floor(input); - } else if (op_name == "Log") { - return std::log(input); - } else if (op_name == "Neg") { - return -input; - } else if (op_name == "Reciprocal") { - return 1.0 / input; - } else if (op_name == "Rsqrt") { - return 1.0 / std::sqrt(input); - } else if (op_name == "Sin") { - return std::sin(input); - } else if (op_name == "Sinh") { - return std::sinh(input); - } else if (op_name == "Sqrt") { - return std::sqrt(input); - } else if (op_name == "Tan") { - return std::tan(input); - } - EXPECT_TRUE(false); - return 0; - }; - + using OpFunc = std::function; + using ValFunc = float (*)(float); + std::map> op_map; +#define ADD_OP(name, op, compute) \ + op_map[name] = \ + std::make_pair(CreateUnaryOp, static_cast(compute)) + ADD_OP("Abs", ops::Abs, std::abs); + ADD_OP("Acos", ops::Acos, std::acos); + ADD_OP("Acosh", ops::Acosh, std::acosh); + ADD_OP("Asin", ops::Asin, std::asin); + ADD_OP("Asinh", ops::Asinh, std::asinh); + ADD_OP("Atan", ops::Atan, std::atan); + ADD_OP("Atanh", ops::Atanh, std::atanh); + op_map["Cast"] = std::make_pair(CreateCastOp, [](float x) { return x; }); + ADD_OP("Ceil", ops::Ceil, std::ceil); + ADD_OP("Cos", ops::Cos, std::cos); + ADD_OP("Cosh", ops::Cosh, std::cosh); + ADD_OP("Exp", ops::Exp, std::exp); + ADD_OP("Floor", ops::Floor, std::floor); + ADD_OP("Log", ops::Log, std::log); + ADD_OP("Neg", ops::Neg, [](float x) { return -x; }); + ADD_OP("Reciprocal", ops::Reciprocal, [](float x) { return 1.0f / x; }); + ADD_OP("Rsqrt", ops::Rsqrt, [](float x) { return 1.0f / std::sqrt(x); }); + ADD_OP("Sin", ops::Sin, std::sin); + ADD_OP("Sinh", ops::Sinh, std::sinh); + ADD_OP("Sqrt", ops::Sqrt, std::sqrt); + ADD_OP("Tan", ops::Tan, std::tan); +#undef ADD_OP // Get list of ops to test. std::vector ops_to_test; // Add all ops supported by ConvertUnary. @@ -5251,26 +5244,35 @@ TEST_F(OpConverterTest, ConvertUnary) { } // Add other unary ops to test. ops_to_test.push_back("Rsqrt"); - // Ok. 
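The unary ADD_OP table above stores each reference implementation as a plain float (*)(float); the static_cast<ValFunc>(...) in the macro is what selects the float overload of functions such as std::sqrt out of their overload sets. The snippet below isolates that idiom in standard C++ (nothing TF-specific); a small lambda wrapper, also shown, is an equally valid and slightly more conservative alternative.

#include <cassert>
#include <cmath>

int main() {
  using ValFunc = float (*)(float);

  // std::sqrt has float/double/long double overloads; the cast picks one.
  ValFunc f = static_cast<ValFunc>(std::sqrt);
  assert(f(4.0f) == 2.0f);

  // Equivalent result without taking the address of a library function.
  auto g = [](float x) { return std::sqrt(x); };
  assert(g(9.0f) == 3.0f);
  return 0;
}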
+ // Prepare test parameters + auto p = TestParamBase{ + {1, 1, 2, 3}, // input dims + {}, // input partial dims + {1, 1, 2, 3}, // expected output dims + }; for (const string& op_name : ops_to_test) { + SCOPED_TRACE(op_name); Reset(); - NodeDef node_def = get_unary_nodedef(op_name); - AddTestTensor("input", {1, 2, 3}); - RunValidationAndConversion(node_def); - TRT_TensorOrWeights output; - TF_EXPECT_OK(GetTensorOrWeights("my_unary", &output)); - ASSERT_TRUE(output.is_tensor()); - ExpectTrtDimsEqualsArray({1, 2, 3}, output.tensor()->getDimensions()); - - const std::vector input = {-0.9f, 0.6f, 0.0f, -3.5f, 100.0f, 2.9f}; - const DataVec input_data{{"input", AsTensor(input)}}; - DataVec output_data{{"my_unary", ConstructTensor(6)}}; - BuildAndRun(input_data, &output_data); - for (int i = 0; i < input.size(); ++i) { - const float expected_output = get_unary_output(op_name, input[i]); - EXPECT_THAT(GetSpanForData(output_data[0])[i], - NanSensitiveFloatNear(expected_output, 0.0001)); + if (!op_map.count(op_name)) { + FAIL() << "Unary op test map does not contain op " << op_name; } + NodeDef node_def = op_map[op_name].first(tf_dtype); + + // TODO(bixia): we assume this test is only instantiated for DT_FLOAT for + // now. Need to find a better way to express input and output types. + // + // TODO(tfeher): improve tests by defining an expected output data type and + // check that. Currently only the shape and values of the output are + // checked. + DataType input_tf_dtype = op_name == "Cast" ? DT_HALF : tf_dtype; + + std::vector input_values{-0.9f, 0.6f, 0.0f, -3.5f, 100.0f, 2.9f}; + AddTestTensor("input", p.input_dims, input_tf_dtype, input_values); + std::vector output; + std::transform(input_values.begin(), input_values.end(), + std::back_inserter(output), op_map[op_name].second); + TestOpConverter("my_unary", node_def, p.expected_output_dims, Status::OK(), + p.runtime_status, ArrayFloatNear(output, 0.0001, true)); } } @@ -5374,7 +5376,7 @@ void TestConvertConcat(OpConverterTest* test) { DataVec output_data{ {"my_concat", test->ConstructTensor(ok_params[i].expected_output.size())}}; - test->BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(ok_params[i].expected_output)); } @@ -5539,7 +5541,7 @@ void TestConvertSplit(OpConverterTest* test) { // Verify output values are correct. const DataVec input_data{ {"value", test->AsTensor(ok_params[i].value)}}; - test->BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); for (int j = 0; j < outputs.size(); ++j) { EXPECT_THAT(GetSpanForData(output_data[j]), ElementsAreArray(ok_params[i].expected_outputs[j])); @@ -5716,7 +5718,7 @@ void TestConvertUnpack(OpConverterTest* test) { // Verify output values are correct. 
const DataVec input_data{ {"value", test->AsTensor(ok_params[i].value)}}; - test->BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); for (int j = 0; j < outputs.size(); ++j) { EXPECT_THAT(GetSpanForData(output_data[j]), ElementsAreArray(ok_params[i].expected_outputs[j])); @@ -5885,7 +5887,7 @@ void TestConvertPack(OpConverterTest* test) { } DataVec output_data{{"my_pack", test->ConstructTensor( params[i].expected_output.size())}}; - test->BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(params[i].expected_output)); } @@ -6033,7 +6035,7 @@ void TestConvertArgMinMax(OpConverterTest* test) { DataVec output_data{ {"my_arg", test->ConstructTensor( params[i].expected_argmax_output.size())}}; - test->BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); if (node_def.op() == "ArgMax") { EXPECT_THAT(GetSpanForData(output_data[0]), @@ -6132,7 +6134,7 @@ void TestConvertDepthSpaceShuffle( DataVec input_data{{"input", test->AsTensor(params[i].input_value)}}; DataVec output_data{{"my_shuffle", test->ConstructTensor( params[i].expected_output.size())}}; - test->BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(params[i].expected_output)); } @@ -6408,7 +6410,7 @@ void TestConvertClipByValue(OpConverterTest* test) { DataVec input_data{{"t", test->AsTensor(params[i].input_value)}}; DataVec output_data{{"my_clip", test->ConstructTensor( params[i].expected_output.size())}}; - test->BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(params[i].expected_output)); } @@ -6516,7 +6518,7 @@ void TestConvertSquaredDifference(OpConverterTest* test) { DataVec output_data{ {"my_squared_diff", test->ConstructTensor(params[i].expected_output.size())}}; - test->BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(params[i].expected_output)); } @@ -6621,7 +6623,7 @@ void TestConvertResize(OpConverterTest* test) { {"my_resize", test->ConstructTensor( params[i].expected_nearest_output_values.size())}}; - test->BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); if (node_def.op() == "ResizeBilinear") { ExpectArrayAlmostEqual(params[i].expected_bilinear_output_values, @@ -6721,7 +6723,7 @@ void TestConvertPad(OpConverterTest* test) { {"my_pad", test->ConstructTensor( params[i].expected_output_values.size())}}; - test->BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); ExpectArrayAlmostEqual(params[i].expected_output_values, GetSpanForData(output_data[0]), CType(1e-5)); } diff --git a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc index 6ab719db54d..72f4fe5ef9b 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc @@ -228,9 +228,6 @@ Status TRTOptimizationPass::Optimize(grappler::Cluster* cluster, << "This can result in poor performance."; } } - grappler::GraphProperties static_graph_properties(item); - 
TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true)); - ConversionParams cp; if (use_calibration_ && precision_mode_ != TrtPrecisionMode::INT8) { VLOG(1) << "Calibration with FP32 or FP16 is not implemented. " @@ -255,7 +252,9 @@ Status TRTOptimizationPass::Optimize(grappler::Cluster* cluster, } nodes_to_preserve.push_back(s); } - cp.input_graph_def = &item.graph; + + ConversionParams cp; + cp.grappler_item = &item; cp.output_names = &nodes_to_preserve; cp.trt_logger_name = trt_logger_name_; cp.max_batch_size = maximum_batch_size_; @@ -263,7 +262,6 @@ Status TRTOptimizationPass::Optimize(grappler::Cluster* cluster, cp.output_graph_def = optimized_graph; cp.precision_mode = precision_mode_; cp.minimum_segment_size = minimum_segment_size_; - cp.graph_properties = &static_graph_properties; cp.cluster = cluster; cp.is_dyn_op = is_dynamic_op_; cp.max_cached_engines = max_cached_batches_; diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.cc b/tensorflow/compiler/tf2tensorrt/convert/utils.cc index fb3ae6943d3..a4b64ec0dc5 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/errors.h" namespace tensorflow { namespace tensorrt { @@ -185,6 +186,40 @@ Status TrtDimsToTensorShape(const nvinfer1::Dims trt_dims, return Status::OK(); } +Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type) { + switch (tf_type) { + case DT_FLOAT: + *trt_type = nvinfer1::DataType::kFLOAT; + break; + case DT_HALF: + *trt_type = nvinfer1::DataType::kHALF; + break; + case DT_INT32: + *trt_type = nvinfer1::DataType::kINT32; + break; + default: + return errors::Internal("Unsupported tensorflow type"); + } + return Status::OK(); +} + +Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type) { + switch (trt_type) { + case nvinfer1::DataType::kFLOAT: + *tf_type = DT_FLOAT; + break; + case nvinfer1::DataType::kHALF: + *tf_type = DT_HALF; + break; + case nvinfer1::DataType::kINT32: + *tf_type = DT_INT32; + break; + default: + return errors::Internal("Invalid TRT type"); + } + return Status::OK(); +} + int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine) { int n_bindings = engine->getNbBindings(); int n_input = 0; diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.h b/tensorflow/compiler/tf2tensorrt/convert/utils.h index 5d4cf1bb851..43697573bbd 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.h +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.h @@ -29,6 +29,8 @@ limitations under the License. namespace tensorflow { namespace tensorrt { +static constexpr char kCastOutputTypeAttrName[] = "DstT"; + class IONamePrefixes { public: static constexpr const char* const kInputPHName = "TensorRTInputPH_"; @@ -106,6 +108,9 @@ Status TrtDimsToTensorShape(const nvinfer1::Dims trt_dims, bool use_implicit_batch, int batch_size, TensorShape& shape); +Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type); +Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type); + // Returns a string that includes compile time TensorRT library version // information {Maj, Min, Patch}. 
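The TfTypeToTrtType/TrtTypeToTfType helpers added to utils.cc and declared in utils.h above give a Status-returning mapping between TF and TensorRT dtypes. A minimal usage sketch, assuming the TF/TensorRT headers of this source tree; the RoundTripHalf name and the round-trip check are illustrative, not part of the change.

#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace tensorrt {

// Convert DT_HALF to the TRT enum and back, propagating errors via Status.
Status RoundTripHalf() {
  nvinfer1::DataType trt_type;
  TF_RETURN_IF_ERROR(TfTypeToTrtType(DT_HALF, &trt_type));  // -> kHALF
  DataType tf_type;
  TF_RETURN_IF_ERROR(TrtTypeToTfType(trt_type, &tf_type));  // -> DT_HALF
  if (tf_type != DT_HALF) {
    return errors::Internal("Round trip changed the type");
  }
  return Status::OK();
}

}  // namespace tensorrt
}  // namespace tensorflow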
string GetLinkedTensorRTVersion(); diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index a5332385994..37110442b26 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -81,7 +81,7 @@ tf_portable_proto_library( name = "portable_tf2xla_proto", config_string = "allow_all:true", header_outs = ["//tensorflow/compiler/tf2xla/tf2xla.proto.h"], - portable_deps = ["//tensorflow/core:portable_proto_lib_full_runtime"], + portable_deps = ["//tensorflow/core:portable_proto_lib"], proto_deps = [ ":tf2xla_proto", "//tensorflow/core:protos_all", @@ -182,6 +182,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/strings", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Shape", "@llvm-project//mlir:StandardOps", ], ) @@ -349,6 +350,7 @@ cc_library( ":sharding_util", ":side_effect_util", ":tf2xla_util", + "//tensorflow/compiler/jit:common", "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:shape_inference", "//tensorflow/compiler/jit:xla_cluster_util", @@ -703,12 +705,8 @@ cc_library( deps = [ "//tensorflow/compiler/mlir:mlir_graph_optimization_pass", "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", - "//tensorflow/compiler/mlir/tensorflow:device_util", - "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", - "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/core:core_cpu", - "@com_google_absl//absl/container:flat_hash_set", + "//tensorflow/core:lib", "@llvm-project//llvm:support", ], alwayslink = 1, diff --git a/tensorflow/compiler/tf2xla/graph_compiler_util.cc b/tensorflow/compiler/tf2xla/graph_compiler_util.cc index 57278eea292..a9385e05564 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler_util.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler_util.cc @@ -49,10 +49,12 @@ typedef std::unordered_map NodeMap; // Each feed id identifies the positional output of some node, which may consist // of multiple edges. AddPlaceholdersForFeeds has already replaced each fed // tensor with a placeholder. For each feed tensor, replaces all edges so they -// point from a new _Arg node instead. +// point from a new _Arg node instead. The newly created _Arg nodes are added to +// `arg_nodes`. Status AddArgNodes(Graph* graph, const NodeMap& node_map, const protobuf::RepeatedPtrField& feeds, - const std::unordered_map& feed_remapping) { + const std::unordered_map& feed_remapping, + std::unordered_set* arg_nodes) { for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) { const tf2xla::Feed& feed = feeds[arg_index]; // All feeds have been replaced by placeholders. @@ -86,6 +88,7 @@ Status AddArgNodes(Graph* graph, const NodeMap& node_map, .Attr(kShapeAttr, TensorShape(feed.shape())) .Attr(kDebugNameAttr, feed.name()) .Finalize(graph, &arg_node)); + arg_nodes->insert(arg_node); // Collects out-edges from the feed node that have a matching edge index; // these will be replaced with edges from the arg node instead. 
@@ -149,13 +152,13 @@ Status RewriteAndPruneGraph( for (Node* n : graph->nodes()) { node_map[n->name()] = n; } + std::unordered_set nodes_to_keep; + TF_RETURN_IF_ERROR(AddArgNodes(graph, node_map, config.feed(), feed_remapping, + &nodes_to_keep)); TF_RETURN_IF_ERROR( - AddArgNodes(graph, node_map, config.feed(), feed_remapping)); - std::unordered_set retval_nodes; - TF_RETURN_IF_ERROR( - AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes)); + AddRetvalNodes(graph, node_map, config.fetch(), &nodes_to_keep)); VLOG(2) << "Post rewrite: " << DumpGraphToFile("tf2xla_post_rewrite", *graph); - PruneForReverseReachability(graph, std::move(retval_nodes)); + PruneForReverseReachability(graph, std::move(nodes_to_keep)); FixupSourceAndSinkEdges(graph); VLOG(2) << "Post prune: " << DumpGraphToFile("tfcompile_post_prune", *graph); // Sanity-check, to make sure the feeds and fetches still exist post-pruning. @@ -277,8 +280,16 @@ Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config, // Prune the GraphDef first so that unknown ops that we aren't compiling get // filtered out. GraphDef second_copy_def; + // Add the placeholder nodes as "fetches" in prune_config, such that they will + // be preserved in PruneGraphDefInto. + auto prune_config = config; + for (const auto& entry : feed_remapping) { + auto ph = prune_config.add_fetch(); + *ph->mutable_id()->mutable_node_name() = entry.second; + ph->mutable_id()->set_output_index(0); + } TF_RETURN_IF_ERROR( - PruneGraphDefInto(config, first_copy_def, &second_copy_def)); + PruneGraphDefInto(prune_config, first_copy_def, &second_copy_def)); TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef( &second_copy_def, *g->op_registry(), /*node_offset=*/0)); diff --git a/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc b/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc index fb89742b139..c1f60abc0d6 100644 --- a/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc @@ -106,8 +106,9 @@ class DataFormatVecPermuteOp : public XlaOpKernel { errors::InvalidArgument( "Input must be a vector or matrix, but got shape ", input_tensor_shape.DebugString())); + const int dim0 = input_tensor_shape.dim_size(0); OP_REQUIRES( - ctx, input_tensor_shape.dim_size(0) == 4, + ctx, dim0 == 2 || dim0 == 4, errors::InvalidArgument( "First dimension of input must be of size 4, but got shape ", input_tensor_shape.DebugString())); @@ -118,10 +119,25 @@ class DataFormatVecPermuteOp : public XlaOpKernel { "Second dimension of 2D input must be of size 2, but got shape ", input_tensor_shape.DebugString())); } - int32 dst_indices[4]; - for (int i = 0; i < 4; ++i) { - for (int j = 0; j < 4; ++j) { - if (src_format_[i] == dst_format_[j]) { + + string src_format_str = src_format_; + string dst_format_str = dst_format_; + if (dim0 == 2) { + // If the input is a vector of size 2, treat the two elements as spatial + // dimensions. 
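The index mapping built just below is simple to trace by hand: with a 2-element input only the spatial characters of the format strings matter, so both "NHWC" and "NCHW" reduce to "HW" and the permutation becomes the identity, while the 4-element case yields the familiar NHWC-to-NCHW reorder. A plain C++ version of that dst_indices computation (illustrative; it mirrors the loop added in the kernel):

#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Keep only the spatial dimension characters, as the kernel does for
// 2-element inputs.
std::string SpatialOnly(std::string s) {
  s.erase(std::remove_if(s.begin(), s.end(),
                         [](char c) { return c != 'H' && c != 'W'; }),
          s.end());
  return s;
}

// dst_indices[j] = position in src of the dimension named dst[j].
std::vector<int> PermIndices(const std::string& src, const std::string& dst) {
  std::vector<int> idx(dst.size());
  for (size_t i = 0; i < src.size(); ++i)
    for (size_t j = 0; j < dst.size(); ++j)
      if (src[i] == dst[j]) idx[j] = static_cast<int>(i);
  return idx;
}

int main() {
  // Full 4-element case, NHWC -> NCHW: N stays at 0, C moves from 3 to 1.
  assert((PermIndices("NHWC", "NCHW") == std::vector<int>{0, 3, 1, 2}));
  // 2-element case: both formats reduce to "HW", so the permutation is identity.
  assert((PermIndices(SpatialOnly("NHWC"), SpatialOnly("NCHW")) ==
          std::vector<int>{0, 1}));
  return 0;
}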
+ auto keep_only_spatial_dimensions = [](string* format_str) -> void { + auto new_end = std::remove_if( + format_str->begin(), format_str->end(), + [](const char dim) { return dim != 'H' && dim != 'W'; }); + format_str->erase(new_end, format_str->end()); + }; + keep_only_spatial_dimensions(&src_format_str); + keep_only_spatial_dimensions(&dst_format_str); + } + std::vector dst_indices(dim0); + for (int i = 0; i < dim0; ++i) { + for (int j = 0; j < dim0; ++j) { + if (src_format_str[i] == dst_format_str[j]) { dst_indices[j] = i; break; } diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc index bb2c0d9ddb8..5dbc083368c 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc @@ -28,6 +28,15 @@ limitations under the License. namespace tensorflow { namespace { +absl::InlinedVector SliceVector(xla::XlaOp input, int64 rank) { + absl::InlinedVector scalar_indices; + scalar_indices.reserve(rank); + for (int i = 0; i < rank; i++) + scalar_indices.push_back( + xla::Reshape(xla::Slice(input, {i}, {i + 1}, {1}), {})); + return scalar_indices; +} + class DynamicUpdateSliceOp : public XlaOpKernel { public: explicit DynamicUpdateSliceOp(OpKernelConstruction* context) @@ -41,21 +50,23 @@ class DynamicUpdateSliceOp : public XlaOpKernel { const TensorShape update_shape = ctx->InputShape("update"); const TensorShape index_shape = ctx->InputShape("indices"); + int64 rank = input_shape.dims(); OP_REQUIRES( ctx, TensorShapeUtils::IsVector(index_shape) && - index_shape.num_elements() == input_shape.dims(), + index_shape.num_elements() == rank, errors::InvalidArgument("index must be a vector with length equal to " "the number of input dimensions")); OP_REQUIRES( - ctx, input_shape.dims() == update_shape.dims(), + ctx, rank == update_shape.dims(), errors::InvalidArgument("input and update must have the same rank," " input shape is ", input_shape.DebugString(), "; update shape is ", update_shape.DebugString())); + xla::XlaOp indices = ctx->Input("indices"); xla::XlaOp result = xla::DynamicUpdateSlice( - ctx->Input("input"), ctx->Input("update"), ctx->Input("indices")); + ctx->Input("input"), ctx->Input("update"), SliceVector(indices, rank)); ctx->SetOutput(0, result); } }; @@ -76,17 +87,18 @@ class DynamicSliceOp : public XlaOpKernel { const TensorShape start_indices_shape = ctx->InputShape("start_indices"); const TensorShape size_indices_shape = ctx->InputShape("size_indices"); + int64 rank = input_shape.dims(); OP_REQUIRES(ctx, TensorShapeUtils::IsVector(start_indices_shape) && - start_indices_shape.num_elements() == input_shape.dims(), + start_indices_shape.num_elements() == rank, errors::InvalidArgument( "start_indices must be a vector with length equal to " "input rank, but input rank is ", - input_shape.dims(), " and start_indices has shape ", + rank, " and start_indices has shape ", start_indices_shape.DebugString())); OP_REQUIRES(ctx, TensorShapeUtils::IsVector(size_indices_shape) && - size_indices_shape.num_elements() == input_shape.dims(), + size_indices_shape.num_elements() == rank, errors::InvalidArgument( "size_indices must be a vector with length equal to " "input rank, but input rank is ", @@ -96,8 +108,10 @@ class DynamicSliceOp : public XlaOpKernel { std::vector size_indices; OP_REQUIRES_OK( ctx, ctx->ConstantInputAsIntVector("size_indices", &size_indices)); + + xla::XlaOp start_indices = ctx->Input("start_indices"); xla::XlaOp result = 
xla::DynamicSlice( - ctx->Input("input"), ctx->Input("start_indices"), size_indices); + ctx->Input("input"), SliceVector(start_indices, rank), size_indices); ctx->SetOutput(0, result); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index a3fcb4d4b8f..bd6f58453df 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -81,9 +82,7 @@ class MatMulOp : public XlaOpKernel { b = xla::ConvertElementType(b, xla::F32); } } - auto lhs = (transpose_a_) ? xla::Transpose(a, {1, 0}) : a; - auto rhs = (transpose_b_) ? xla::Transpose(b, {1, 0}) : b; - ctx->SetOutput(0, xla::Dot(lhs, rhs)); + ctx->SetOutput(0, xla::BatchDot(a, transpose_a_, b, transpose_b_)); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index 8431724f438..beb8e7aa174 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -36,10 +36,8 @@ limitations under the License. namespace tensorflow { namespace { -// TODO(phawkins): implement double-sized windowed reductions in XLA and remove -// the type constraint. -constexpr std::array kScanOpTypes = { - {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_INT32}}; +constexpr std::array kScanOpTypes = { + {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32}}; class ScanOp : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index 17d0b87edda..7f274c6b00f 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -42,19 +42,17 @@ class SliceOp : public XlaOpKernel { const TensorShape begin_tensor_shape = ctx->InputShape(1); const TensorShape size_tensor_shape = ctx->InputShape(2); + const int input_dims = input_shape.dims(); OP_REQUIRES( ctx, TensorShapeUtils::IsVector(begin_tensor_shape) && TensorShapeUtils::IsVector(size_tensor_shape) && - begin_tensor_shape.num_elements() == input_shape.dims() && - size_tensor_shape.num_elements() == input_shape.dims(), + begin_tensor_shape.num_elements() == input_dims && + size_tensor_shape.num_elements() == input_dims, errors::InvalidArgument( "Expected begin and size arguments to be 1-D tensors of size ", - input_shape.dims(), ", but got shapes ", - begin_tensor_shape.DebugString(), " and ", - size_tensor_shape.DebugString(), " instead.")); - - const int input_dims = input_shape.dims(); + input_dims, ", but got shapes ", begin_tensor_shape.DebugString(), + " and ", size_tensor_shape.DebugString(), " instead.")); std::vector begin; std::vector size; @@ -129,7 +127,15 @@ class SliceOp : public XlaOpKernel { input_shape.dim_size(i), "], but ", "got ", size[i])); } - ctx->SetOutput(0, xla::DynamicSlice(ctx->Input(0), ctx->Input(1), size)); + + absl::InlinedVector scalar_indices; + scalar_indices.reserve(input_dims); + xla::XlaOp begin = ctx->Input("begin"); + for (int i = 0; i < input_dims; i++) + scalar_indices.push_back( + xla::Reshape(xla::Slice(begin, {i}, {i + 1}, {1}), {})); + + ctx->SetOutput(0, 
xla::DynamicSlice(ctx->Input(0), scalar_indices, size)); } } }; diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc index fa5a96ca6bd..976ff91f6ce 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -136,8 +136,11 @@ class TensorListReserveOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements)); OP_REQUIRES( ctx, num_elements >= 0, - errors::InvalidArgument("XLA compilation requires a fixed tensor list " - "size. Set the number of elements.")); + errors::InvalidArgument( + "XLA compilation requires a fixed tensor list size. Set the number " + "of elements. This could also happen if you're using a TensorArray " + "in a while loop that does not have its maximum_iterations set; you " + "can fix this by setting maximum_iterations to a suitable value.")); // If element shape is compile time constant and it's not "unknown rank" // shape (-1), create an initialized TensorList. Otherwise create an @@ -197,10 +200,13 @@ class EmptyTensorListOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { int64 max_num_elements; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &max_num_elements)); - OP_REQUIRES( - ctx, max_num_elements >= 0, - errors::InvalidArgument("XLA compilation requires a fixed tensor list " - "size. Set the max number of elements.")); + OP_REQUIRES(ctx, max_num_elements >= 0, + errors::InvalidArgument( + "XLA compilation requires a fixed tensor list size. Set " + "the max number of elements. This could also happen if " + "you're using a TensorArray in a while loop that does not " + "have its maximum_iterations set; you can fix this by " + "setting maximum_iterations to a suitable value.")); if (dtype_ != DT_VARIANT) { // We are creating a non-nested TensorList. @@ -431,6 +437,120 @@ class TensorListStackOp : public XlaOpKernel { REGISTER_XLA_OP(Name("TensorListStack"), TensorListStackOp); +class TensorListConcatOp : public XlaOpKernel { + public: + explicit TensorListConcatOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp input = ctx->Input(0); + + // Check that the TensorList is initialized. + bool is_initialized; + OP_REQUIRES_OK(ctx, (IsTensorListInitialized(input, &is_initialized))); + OP_REQUIRES(ctx, is_initialized, + errors::InvalidArgument("TensorList is not initialized")); + + // Only non-nested TensorList is supported for now.
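+ // (A nested TensorList is one whose elements are themselves TensorLists;
+ // the check below rejects that case.) The concatenation itself is done by
+ // collapsing the two leading dimensions of the list buffer, i.e. by
+ // reshaping [num_elements, element_len, ...] into
+ // [num_elements * element_len, ...].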
+ bool is_nested; + OP_REQUIRES_OK(ctx, IsNestedTensorList(input, &is_nested)); + OP_REQUIRES(ctx, !is_nested, + errors::Unimplemented("Only non-nested TensorList is supported " + "for TensorListConcat.")); + + xla::XlaOp buffer; + OP_REQUIRES_OK(ctx, GetTensorListBuffer(input, &buffer)); + + xla::XlaBuilder* b = input.builder(); + auto shape_or = b->GetShape(buffer); + OP_REQUIRES_OK(ctx, shape_or.status()); + xla::Shape element_shape = shape_or.ConsumeValueOrDie(); + std::vector element_dims = + xla::SpanToVector(element_shape.dimensions()); + OP_REQUIRES( + ctx, element_dims.size() > 1, + errors::Unimplemented("TensorList of scalars is not supported")); + int64 num_elements = element_dims[0]; + int64 tensor_lengths = element_dims[1]; + + std::vector new_dims = {num_elements * tensor_lengths}; + + for (int i = 2; i < element_dims.size(); i++) { + new_dims.push_back(element_dims[i]); + } + + xla::XlaOp out = xla::Reshape(buffer, new_dims); + ctx->SetOutput(0, out); + + // Second output is a tensor of lengths of returned tensors. + xla::XlaOp lengths = xla::ConstantR1(b, num_elements, tensor_lengths); + ctx->SetOutput(1, lengths); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(TensorListConcatOp); +}; + +REGISTER_XLA_OP(Name("TensorListConcatV2"), TensorListConcatOp); + +class TensorListSplitOp : public XlaOpKernel { + public: + explicit TensorListSplitOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + // Only non-nested TensorList is supported for now. + OP_REQUIRES( + ctx, dtype_ != DT_VARIANT, + errors::Unimplemented( + "Only non-nested TensorList is supported for TensorListReserve.")); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp input_tensor = ctx->Input(0); + + xla::XlaBuilder* b = input_tensor.builder(); + auto shape_or = b->GetShape(input_tensor); + OP_REQUIRES_OK(ctx, shape_or.status()); + xla::Shape element_shape = shape_or.ConsumeValueOrDie(); + std::vector element_dims = + xla::SpanToVector(element_shape.dimensions()); + OP_REQUIRES( + ctx, !element_dims.empty(), + errors::Unimplemented("Element dimensions have to be non-empty")); + + std::vector lengths; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &lengths)); + OP_REQUIRES(ctx, !lengths.empty(), + errors::Unimplemented("Length has to be non-empty")); + int64 length = lengths[0]; + for (int64 len : lengths) { + OP_REQUIRES(ctx, len == length, + errors::Unimplemented("All lengths have to be the same")); + } + OP_REQUIRES( + ctx, element_dims[0] % length == 0, + errors::Unimplemented("Buffer size has to be a multiple of length")); + std::vector new_dims = {element_dims[0] / length, length}; + for (int i = 1; i < element_dims.size(); i++) { + new_dims.push_back(element_dims[i]); + } + + xla::XlaOp reshaped = xla::Reshape(input_tensor, new_dims); + + xla::XlaOp result; + OP_REQUIRES_OK(ctx, ExecuteTensorListFromTensor(length, reshaped, &result)); + ctx->SetTensorListOutput(0, result); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListSplitOp); +}; + +REGISTER_XLA_OP(Name("TensorListSplit") + .CompileTimeConstantInput("element_shape") + .CompileTimeConstantInput("lengths"), + TensorListSplitOp); + class TensorListFromTensorOp : public XlaOpKernel { public: explicit TensorListFromTensorOp(OpKernelConstruction* ctx) diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc index 499e27f0981..c398e5f129e 100644 --- 
a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc @@ -18,10 +18,18 @@ limitations under the License. #include #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h" +#include "tensorflow/core/lib/monitoring/gauge.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { +auto* mlir_bridge_gauge_v1 = monitoring::Gauge::New( + "/tensorflow/config/experimental/enable_mlir_bridge_gauge_v1", + "Tracks usage of the MLIR-based TF2XLA bridge among TF1 models"); +auto* mlir_bridge_gauge_v2 = monitoring::Gauge::New( + "/tensorflow/config/experimental/enable_mlir_bridge_gauge_v2", + "Tracks usage of the MLIR-based TF2XLA bridge among TF2 models"); + // This runs the first phase of the "bridge", transforming the graph in a form // that can be executed with delegation of some computations to an accelerator. // This builds on the model of XLA where a subset of the graph is encapsulated @@ -32,10 +40,12 @@ Status MlirBridgePass::Run(const ConfigProto& config_proto, mlir::ModuleOp module) { if (!config_proto.experimental().enable_mlir_bridge()) { VLOG(0) << "Skipping MLIR TPU Bridge, session flag not enabled"; + mlir_bridge_gauge_v2->GetCell()->Set(false); return Status::OK(); } VLOG(0) << "Running MLIR TPU Bridge"; + mlir_bridge_gauge_v2->GetCell()->Set(true); TF_RETURN_IF_ERROR( mlir::TFTPU::TPUBridge(module, /*enable_logging=*/VLOG_IS_ON(1))); @@ -48,10 +58,12 @@ Status MlirBridgeV1CompatPass::Run(const GraphOptimizationPassOptions& options, if (!options.session_options->config.experimental().enable_mlir_bridge()) { VLOG(0) << "Skipping MLIR TPU Bridge V1 Compat, session flag not enabled"; + mlir_bridge_gauge_v1->GetCell()->Set(false); return Status::OK(); } VLOG(0) << "Running MLIR TPU Bridge V1 Compat"; + mlir_bridge_gauge_v1->GetCell()->Set(true); TF_RETURN_IF_ERROR( mlir::TFTPU::TPUBridgeV1Compat(module, /*enable_logging=*/VLOG_IS_ON(1))); diff --git a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc index daf261fa5d8..43793be56a7 100644 --- a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc +++ b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include #include +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -95,6 +96,7 @@ static void RegisterDialects() { mlir::registerDialect(); mlir::registerDialect(); mlir::registerDialect(); + mlir::registerDialect(); return true; }(); (void)init_once; diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index 24afe595b18..7ea69f734c9 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -99,5 +99,42 @@ TEST(ConvertGraphDefToXla, Sum) { ConvertGraphDefToXla(graph_def, config, client, &computation))); } +TEST(ConvertGraphDefToXla, SumWithUnusedArgument) { + GraphDef graph_def = SumGraph(); + tf2xla::Config config = SumConfig(); + NodeDef* unused = graph_def.add_node(); + unused->set_name("unused"); + unused->set_op("Placeholder"); + (*unused->mutable_attr())["dtype"] = TypeAttrValue(DT_INT32); + config.add_feed()->mutable_id()->set_node_name("unused"); + + xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie(); + xla::XlaComputation computation; + TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation)); + + // Set up arguments. + auto x_literal = xla::LiteralUtil::CreateR0(10); + auto y_literal = xla::LiteralUtil::CreateR0(32); + auto x_global_or = client->TransferToServer(x_literal); + auto y_global_or = client->TransferToServer(y_literal); + auto unused_global_or = client->TransferToServer(y_literal); + TF_EXPECT_OK(x_global_or.status()); + TF_EXPECT_OK(y_global_or.status()); + TF_EXPECT_OK(unused_global_or.status()); + std::unique_ptr x_global = + std::move(x_global_or.ValueOrDie()); + std::unique_ptr y_global = + std::move(y_global_or.ValueOrDie()); + std::unique_ptr unused_global = + std::move(unused_global_or.ValueOrDie()); + + // Execute and check result. + auto result_or = client->ExecuteAndTransfer( + computation, {x_global.get(), y_global.get(), unused_global.get()}); + TF_EXPECT_OK(result_or.status()); + xla::Literal result = std::move(result_or.ValueOrDie()); + EXPECT_EQ("(\ns32[] 42\n)", result.ToString()); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 3d6083621f4..1cf3e10b774 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "absl/types/variant.h" +#include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/shape_inference.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" @@ -571,6 +572,10 @@ std::unique_ptr XlaCompiler::GetGraph(const FunctionBody* fbody) { std::unique_ptr graph(new Graph(options_.flib_def)); CopyGraph(*fbody->graph, graph.get()); + bool is_inside_mustcompile = false; + TryGetNodeAttr(AttrSlice(&fbody->fdef.attr()), kXlaMustCompileAttr, + &is_inside_mustcompile); + // Performs a first function inlining pass before shape inference, since // otherwise shape inference can't see inside functions and a comprehensive // shape_map, including function ops, is needed to constant-propagate Shape @@ -622,6 +627,8 @@ std::unique_ptr XlaCompiler::GetGraph(const FunctionBody* fbody) { graph_optimizer_options.inline_multi_device_functions = true; graph_optimizer_options.inline_impl_selection_group_functions = true; graph_optimizer_options.inline_with_single_device_body_placer = true; + graph_optimizer_options.ignore_noinline = is_inside_mustcompile; + optimizer.Optimize(flib_runtime_, flib_runtime_->env(), /*device=*/nullptr, &graph, graph_optimizer_options); diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 1350f9e3e0b..45f49cee328 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -17,7 +17,6 @@ package_group( "//tensorflow/compiler/...", "//tensorflow/python/tpu/...", "//third_party/py/jax/...", - "//third_party/tf_runtime/tools/tf_kernel_gen/...", ], ) @@ -332,6 +331,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc index cd52e2f5e45..404f9eb7519 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.cc +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -70,6 +70,12 @@ ExecutableBuildOptions& ExecutableBuildOptions::set_num_partitions( return *this; } +ExecutableBuildOptions& ExecutableBuildOptions::set_use_spmd_partitioning( + bool use_spmd_partitioning) { + use_spmd_partitioning_ = use_spmd_partitioning; + return *this; +} + ExecutableBuildOptions& ExecutableBuildOptions::set_device_assignment( const DeviceAssignment& device_assignment) { device_assignment_ = device_assignment; diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index 360ad0260df..9a7fdd974b1 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -77,6 +77,11 @@ class ExecutableBuildOptions { int num_partitions() const { return num_partitions_; } ExecutableBuildOptions& set_num_partitions(int num_partitions); + // Indicates whether to use SPMD (true) or MPMD (false) partitioning when + // num_partitions > 1 and XLA is requested to partition the input program. + bool use_spmd_partitioning() const { return use_spmd_partitioning_; } + ExecutableBuildOptions& set_use_spmd_partitioning(bool use_spmd_partitioning); + // If set, this specifies a static device assignment for the computation. 
// Otherwise, the computation will be compiled generically and can be run with // any device assignment compatible with the computation's replica and @@ -104,6 +109,7 @@ class ExecutableBuildOptions { se::DeviceMemoryAllocator* device_allocator_ = nullptr; int num_replicas_ = 1; int num_partitions_ = 1; + bool use_spmd_partitioning_ = false; absl::optional device_assignment_; bool alias_passthrough_params_ = false; }; diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc index 9b8156efe5b..cb79b2ef7db 100644 --- a/tensorflow/compiler/xla/client/lib/math_test.cc +++ b/tensorflow/compiler/xla/client/lib/math_test.cc @@ -236,6 +236,19 @@ XLA_TEST_F(MathTest, SqrtF32) { ComputeAndCompareR0(&builder, 0.0f, {zero_data.get()}, error_spec_); } +XLA_TEST_F(MathTest, SqrtF64) { + XlaBuilder builder(TestName()); + Literal zero_literal = LiteralUtil::Zero(PrimitiveType::F64); + + std::unique_ptr zero_data = + client_->TransferToServer(zero_literal).ConsumeValueOrDie(); + + XlaOp zero = Parameter(&builder, 0, zero_literal.shape(), "zero"); + Sqrt(zero); + + ComputeAndCompareR0(&builder, 0.0f, {zero_data.get()}, error_spec_); +} + #ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64 XLA_TEST_F(MathTest, ErfInvF64) { XlaBuilder builder(TestName()); diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index a779086f1d5..58365c0f498 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -860,34 +860,10 @@ XlaOp XlaBuilder::SliceInDim(XlaOp operand, int64 start_index, }); } -XlaOp XlaBuilder::DynamicSlice(XlaOp operand, XlaOp start_indices, - absl::Span slice_sizes) { - return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; - - TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); - TF_ASSIGN_OR_RETURN(const Shape* start_indices_shape, - GetShapePtr(start_indices)); - TF_ASSIGN_OR_RETURN( - Shape shape, ShapeInference::InferDynamicSliceShape( - *operand_shape, {*start_indices_shape}, slice_sizes)); - *instr.mutable_shape() = shape.ToProto(); - - for (int64 size : slice_sizes) { - instr.add_dynamic_slice_sizes(size); - } - - return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, - {operand, start_indices}); - }); -} - XlaOp XlaBuilder::DynamicSlice(XlaOp operand, absl::Span start_indices, absl::Span slice_sizes) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); std::vector start_indices_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes, @@ -898,43 +874,28 @@ XlaOp XlaBuilder::DynamicSlice(XlaOp operand, TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDynamicSliceShape( *operand_shape, start_indices_shapes, slice_sizes)); - *instr.mutable_shape() = shape.ToProto(); - - for (int64 size : slice_sizes) { - instr.add_dynamic_slice_sizes(size); - } - - std::vector operands = {operand}; - operands.insert(operands.end(), start_indices.begin(), start_indices.end()); - return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands); + return DynamicSliceInternal(shape, operand, start_indices, slice_sizes); }); } -XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update, - XlaOp start_indices) { - return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; +StatusOr XlaBuilder::DynamicSliceInternal( + const Shape& shape, XlaOp operand, absl::Span 
start_indices, + absl::Span slice_sizes) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); - TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); - TF_ASSIGN_OR_RETURN(const Shape* update_shape, GetShapePtr(update)); - TF_ASSIGN_OR_RETURN(const Shape* start_indices_shape, - GetShapePtr(start_indices)); - TF_ASSIGN_OR_RETURN( - Shape shape, - ShapeInference::InferDynamicUpdateSliceShape( - *operand_shape, *update_shape, {*start_indices_shape})); - *instr.mutable_shape() = shape.ToProto(); + for (int64 size : slice_sizes) { + instr.add_dynamic_slice_sizes(size); + } - return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice, - {operand, update, start_indices}); - }); + std::vector operands = {operand}; + operands.insert(operands.end(), start_indices.begin(), start_indices.end()); + return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands); } XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update, absl::Span start_indices) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(const Shape* update_shape, GetShapePtr(update)); std::vector start_indices_shape_ptrs; @@ -946,15 +907,22 @@ XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update, TF_ASSIGN_OR_RETURN( Shape shape, ShapeInference::InferDynamicUpdateSliceShape( *operand_shape, *update_shape, start_indices_shapes)); - *instr.mutable_shape() = shape.ToProto(); - - std::vector operands = {operand, update}; - operands.insert(operands.end(), start_indices.begin(), start_indices.end()); - return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice, - operands); + return DynamicUpdateSliceInternal(shape, operand, update, start_indices); }); } +StatusOr XlaBuilder::DynamicUpdateSliceInternal( + const Shape& shape, XlaOp operand, XlaOp update, + absl::Span start_indices) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + + std::vector operands = {operand, update}; + operands.insert(operands.end(), start_indices.begin(), start_indices.end()); + return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice, + operands); +} + XlaOp XlaBuilder::ConcatInDim(absl::Span operands, int64 dimension) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -1301,7 +1269,6 @@ XlaOp XlaBuilder::ConvGeneralDilated( int64 feature_group_count, int64 batch_group_count, const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs)); TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs)); TF_RETURN_IF_ERROR( @@ -1314,30 +1281,45 @@ XlaOp XlaBuilder::ConvGeneralDilated( window_dimensions[i] = rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i)); } - TF_ASSIGN_OR_RETURN(*instr.mutable_window(), + + TF_ASSIGN_OR_RETURN(Window window, ShapeInference::InferWindowFromDimensions( window_dimensions, window_strides, padding, lhs_dilation, rhs_dilation)); - - TF_ASSIGN_OR_RETURN( - Shape shape, ShapeInference::InferConvolveShape( - *lhs_shape, *rhs_shape, feature_group_count, - batch_group_count, instr.window(), dimension_numbers)); - *instr.mutable_shape() = shape.ToProto(); - - *instr.mutable_convolution_dimension_numbers() = dimension_numbers; - instr.set_feature_group_count(feature_group_count); - instr.set_batch_group_count(batch_group_count); - - if (precision_config != nullptr) { 
- *instr.mutable_precision_config() = *precision_config; - } - - return AddInstruction(std::move(instr), HloOpcode::kConvolution, - {lhs, rhs}); + TF_ASSIGN_OR_RETURN(Shape shape, + ShapeInference::InferConvolveShape( + *lhs_shape, *rhs_shape, feature_group_count, + batch_group_count, window, dimension_numbers)); + return ConvGeneralDilatedInternal(shape, lhs, rhs, window, window_strides, + padding, lhs_dilation, rhs_dilation, + dimension_numbers, feature_group_count, + batch_group_count, precision_config); }); } +StatusOr XlaBuilder::ConvGeneralDilatedInternal( + const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + + *instr.mutable_window() = window; + *instr.mutable_convolution_dimension_numbers() = dimension_numbers; + instr.set_feature_group_count(feature_group_count); + instr.set_batch_group_count(batch_group_count); + + if (precision_config != nullptr) { + *instr.mutable_precision_config() = *precision_config; + } + + return AddInstruction(std::move(instr), HloOpcode::kConvolution, {lhs, rhs}); +} + XlaOp XlaBuilder::Fft(XlaOp operand, const FftType fft_type, const absl::Span fft_length) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -2203,6 +2185,39 @@ XlaOp XlaBuilder::BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean, }); } +XlaOp XlaBuilder::AllGather(XlaOp operand, int64 all_gather_dimension, + int64 shard_count, + absl::Span replica_groups, + const absl::optional& channel_id, + const absl::optional& layout) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); + + TF_ASSIGN_OR_RETURN(Shape inferred_shape, + ShapeInference::InferAllGatherShape( + *operand_shape, all_gather_dimension, shard_count)); + if (layout) { + *inferred_shape.mutable_layout() = *layout; + instr.set_constrain_layout(true); + } + *instr.mutable_shape() = inferred_shape.ToProto(); + + instr.add_dimensions(all_gather_dimension); + for (const ReplicaGroup& group : replica_groups) { + *instr.add_replica_groups() = group; + } + if (channel_id.has_value()) { + instr.set_channel_id(channel_id->handle()); + } + + TF_ASSIGN_OR_RETURN( + auto all_gather, + AddInstruction(std::move(instr), HloOpcode::kAllGather, {operand})); + return all_gather; + }); +} + XlaOp XlaBuilder::CrossReplicaSum( XlaOp operand, absl::Span replica_groups) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -3105,20 +3120,11 @@ XlaOp SliceInDim(const XlaOp operand, int64 start_index, int64 limit_index, stride, dimno); } -XlaOp DynamicSlice(const XlaOp operand, const XlaOp start_indices, - absl::Span slice_sizes) { - return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes); -} XlaOp DynamicSlice(const XlaOp operand, absl::Span start_indices, absl::Span slice_sizes) { return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes); } -XlaOp DynamicUpdateSlice(const XlaOp operand, const XlaOp update, - const XlaOp start_indices) { - return operand.builder()->DynamicUpdateSlice(operand, update, start_indices); -} - XlaOp DynamicUpdateSlice(const XlaOp operand, const XlaOp update, absl::Span start_indices) { return operand.builder()->DynamicUpdateSlice(operand, update, 
start_indices); @@ -3182,6 +3188,10 @@ XlaOp Compare(const XlaOp lhs, const XlaOp rhs, broadcast_dimensions, direction); } +XlaOp Compare(const XlaOp lhs, const XlaOp rhs, ComparisonDirection direction) { + return Compare(lhs, rhs, {}, direction); +} + XlaOp Dot(const XlaOp lhs, const XlaOp rhs, const PrecisionConfig* precision_config) { return lhs.builder()->Dot(lhs, rhs, precision_config); @@ -3470,6 +3480,16 @@ XlaOp ReduceWindowWithGeneralPadding( base_dilations, window_dilations, padding); } +XlaOp AllGather(const XlaOp operand, int64 all_gather_dimension, + int64 shard_count, + absl::Span replica_groups, + const absl::optional& channel_id, + const absl::optional& layout) { + return operand.builder()->AllGather(operand, all_gather_dimension, + shard_count, replica_groups, channel_id, + layout); +} + XlaOp CrossReplicaSum(const XlaOp operand, absl::Span replica_groups) { return operand.builder()->CrossReplicaSum(operand, replica_groups); diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 2ab4c575862..426b6d83207 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -421,16 +421,17 @@ class XlaBuilder { virtual XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno); - ABSL_DEPRECATED("Use span-of-indices form instead") - XlaOp DynamicSlice(XlaOp operand, XlaOp start_indices, - absl::Span slice_sizes); XlaOp DynamicSlice(XlaOp operand, absl::Span start_indices, absl::Span slice_sizes); + virtual StatusOr DynamicSliceInternal( + const Shape& shape, XlaOp operand, absl::Span start_indices, + absl::Span slice_sizes); - ABSL_DEPRECATED("Use span-of-indices form instead") - XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, XlaOp start_indices); XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, absl::Span start_indices); + virtual StatusOr DynamicUpdateSliceInternal( + const Shape& shape, XlaOp operand, XlaOp update, + absl::Span start_indices); XlaOp ConcatInDim(absl::Span operands, int64 dimension); virtual StatusOr ConcatInDimInternal(const Shape& shape, @@ -491,6 +492,16 @@ class XlaBuilder { int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); + virtual StatusOr ConvGeneralDilatedInternal( + const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config); + XlaOp Fft(XlaOp operand, FftType fft_type, absl::Span fft_length); @@ -549,6 +560,12 @@ class XlaBuilder { XlaOp CrossReplicaSum(XlaOp operand, absl::Span replica_groups = {}); + XlaOp AllGather( + XlaOp operand, int64 all_gather_dimension, int64 shard_count, + absl::Span replica_groups = {}, + const absl::optional& channel_id = absl::nullopt, + const absl::optional& layout = absl::nullopt); + XlaOp AllReduce( XlaOp operand, const XlaComputation& computation, absl::Span replica_groups = {}, @@ -842,14 +859,10 @@ class XlaBuilder { friend XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno); - friend XlaOp DynamicSlice(XlaOp operand, XlaOp start_indices, - absl::Span slice_sizes); friend XlaOp DynamicSlice(XlaOp operand, absl::Span start_indices, absl::Span slice_sizes); - friend XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, - XlaOp 
start_indices); friend XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, absl::Span start_indices); @@ -876,6 +889,7 @@ class XlaBuilder { friend XlaOp Compare(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions, ComparisonDirection direction); + friend XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction); friend XlaOp Dot(XlaOp lhs, XlaOp rhs, const PrecisionConfig* precision_config); friend XlaOp DotGeneral(XlaOp lhs, XlaOp rhs, @@ -992,6 +1006,11 @@ class XlaBuilder { absl::Span> padding); friend XlaOp CrossReplicaSum(XlaOp operand, absl::Span replica_groups); + friend XlaOp AllGather(XlaOp operand, int64 all_gather_dimension, + int64 shard_count, + absl::Span replica_groups, + const absl::optional& channel_id, + const absl::optional& layout); friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation, absl::Span replica_groups, const absl::optional& channel_id, @@ -1417,10 +1436,6 @@ XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index, XlaOp DynamicSlice(XlaOp operand, absl::Span start_indices, absl::Span slice_sizes); -ABSL_DEPRECATED("Use span-of-indices form instead") -XlaOp DynamicSlice(XlaOp operand, XlaOp start_indices, - absl::Span slice_sizes); - // Enqueues a dynamic update slice operation onto the computation, which // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. // The shape of 'update' determines the shape of the slice of 'operand' @@ -1441,9 +1456,6 @@ XlaOp DynamicSlice(XlaOp operand, XlaOp start_indices, XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, absl::Span start_indices); -ABSL_DEPRECATED("Use span-of-indices form instead") -XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, XlaOp start_indices); - // Enqueues a concatenate instruction onto the computation. 'operands' must // have >= 1 entry. XlaOp ConcatInDim(XlaBuilder* builder, absl::Span operands, @@ -1487,10 +1499,12 @@ XlaOp Lt(XlaOp lhs, XlaOp rhs, XlaOp Le(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); -// Enqueues a comparison instruction onto the computation. +// Enqueues a comparison instruction onto the computation (optionally without +// broadcast_dimensions for consistency with others). XlaOp Compare(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions, ComparisonDirection direction); +XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction); // Enqueues a dot instruction onto the computation. XlaOp Dot(XlaOp lhs, XlaOp rhs, @@ -1771,6 +1785,11 @@ XlaOp ReduceWindowWithGeneralPadding( XlaOp CrossReplicaSum(XlaOp operand, absl::Span replica_groups = {}); +XlaOp AllGather(XlaOp operand, int64 all_gather_dimension, int64 shard_count, + absl::Span replica_groups = {}, + const absl::optional& channel_id = absl::nullopt, + const absl::optional& layout = absl::nullopt); + // Enqueues an operation that do an AllReduce of the operand cross cores. Here // AllReduce means doing a reduction on the input operand cross cores and then // broadcasting the reduction result to those cores. 
The reduction function is diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index 1fa839b2014..4fa47077fca 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -381,6 +381,29 @@ TEST_F(XlaBuilderTest, Transpose) { EXPECT_THAT(root, op::Transpose(op::Parameter())); } +TEST_F(XlaBuilderTest, AllGatherR1) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4}), "x"); + AllGather(x, /*all_gather_dimension=*/0, /*shard_count=*/4); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + + EXPECT_EQ(root->opcode(), HloOpcode::kAllGather); + EXPECT_TRUE(ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {16}))); +} + +TEST_F(XlaBuilderTest, AllGatherR2) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x"); + AllGather(x, /*all_gather_dimension=*/1, /*shard_count=*/4); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + + EXPECT_EQ(root->opcode(), HloOpcode::kAllGather); + EXPECT_TRUE( + ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 64}))); +} + TEST_F(XlaBuilderTest, AllToAll) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x"); diff --git a/tensorflow/compiler/xla/client/xla_computation.h b/tensorflow/compiler/xla/client/xla_computation.h index 3ccbfb28bd0..6a3b17a154a 100644 --- a/tensorflow/compiler/xla/client/xla_computation.h +++ b/tensorflow/compiler/xla/client/xla_computation.h @@ -29,8 +29,8 @@ namespace xla { class XlaComputation { public: XlaComputation() : unique_id_(-1) {} - XlaComputation(const HloModuleProto& proto) - : unique_id_(proto.id()), proto_(proto) {} + XlaComputation(HloModuleProto proto) + : unique_id_(proto.id()), proto_(std::move(proto)) {} ~XlaComputation() {} diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index 216fb0a7422..958629c5fa6 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -55,18 +55,23 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { // b/77879207. opts.set_xla_gpu_disable_multi_streaming(true); - // TODO(jlebar): Disable fastmath once doing so is not a performance - // regression. + // Disable forms of fast math that have caused users problems in the past. opts.set_xla_cpu_enable_fast_math(true); + opts.set_xla_cpu_fast_math_honor_nans(true); + opts.set_xla_cpu_fast_math_honor_infs(true); + opts.set_xla_cpu_fast_math_honor_functions(true); + opts.set_xla_cpu_fast_math_honor_division(true); + + // By default, copy TF's Eigen style min_max behavior with nans. + opts.set_xla_cpu_enable_fast_min_max(false); + opts.set_xla_gpu_enable_fast_min_max(true); opts.set_xla_allow_excess_precision(true); opts.set_xla_force_host_platform_device_count(1); opts.set_xla_gpu_deterministic_reductions(false); opts.set_xla_cpu_enable_xprof_traceme(true); - // TODO(b/155295372): disable ptxas fallback by default. 
- opts.set_xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found(true); - opts.set_xla_gpu_unsafe_fallback_to_driver_on_ptxas_error(true); + opts.set_xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found(false); return opts; } @@ -261,6 +266,12 @@ static void AllocateFlags() { "When xla_cpu_enable_fast_math is true then this controls whether we " "forbid to approximate calculations for functions. Ignored when " "xla_cpu_enable_fast_math is false.")); + flag_objects->push_back(tensorflow::Flag( + "xla_cpu_enable_fast_min_max", + bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_min_max), + flag_values->xla_cpu_enable_fast_min_max(), + "Enable fast floating point min/max lowering that always propagates " + "NaNs.")); flag_objects->push_back(tensorflow::Flag( "xla_gpu_enable_fast_min_max", bool_setter_for(&DebugOptions::set_xla_gpu_enable_fast_min_max), @@ -554,15 +565,6 @@ static void AllocateFlags() { "that falling back to the driver can have drawbacks like using more " "memory and/or other bugs during compilation, so we recommend setting " "this flag to false.")); - flag_objects->push_back(tensorflow::Flag( - "xla_gpu_unsafe_fallback_to_driver_on_ptxas_error", - bool_setter_for( - &DebugOptions::set_xla_gpu_unsafe_fallback_to_driver_on_ptxas_error), - flag_values->xla_gpu_unsafe_fallback_to_driver_on_ptxas_error(), - "If true, XLA GPU falls back to the driver if there is an error when " - "running ptxas. Note that falling back to the driver can have drawbacks " - "like using more memory and/or other bugs during compilation, so we " - "recommend setting this flag to false.")); ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", *flag_objects); } diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h index 43ee0fdd820..8ae8c418d5d 100644 --- a/tensorflow/compiler/xla/executable_run_options.h +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -50,6 +50,7 @@ class RunId { public: // Creates a new, unique RunId. RunId(); + explicit RunId(int64 value) : data_(value) {} RunId(const RunId&) = default; RunId& operator=(const RunId&) = default; diff --git a/tensorflow/compiler/xla/g3doc/index.md b/tensorflow/compiler/xla/g3doc/index.md index b7868fedb8b..60bde306266 100644 --- a/tensorflow/compiler/xla/g3doc/index.md +++ b/tensorflow/compiler/xla/g3doc/index.md @@ -174,9 +174,33 @@ When filing bugs, attach the contents of the `/tmp/generated` directory If possible, try to isolate a bug to a single XLA program by using the -[`replay_computation`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/tools/replay_computation.cc) +[`replay_computation`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/tools/run_hlo_module_main.cc) and iteratively running it on generated programs. +## Known Issues + +Compilation with XLA can greatly improve the performance of your programs, but +the TensorFlow interop has a number of known sharp corners. + +### TensorArray TF/XLA Interconversion + +The problem manifests itself as an error message +`Support for TensorList crossing the XLA/TF boundary is not implemented`. + +XLA supports `tf.TensorArray`. However, the _interconversion_ between TF and +XLA representations is not implemented yet. +This error often arises when the `TensorArray` is used inside the compiled +block, but the derivative is taken outside. + +Workaround: compile the outermost scope which is taking the derivative. 
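+
+For example, here is a minimal sketch of this workaround (assuming the
+`tf.function(experimental_compile=True)` entry point): both the `TensorArray`
+and the gradient stay inside the compiled function, so the list never has to
+cross the XLA/TF boundary.
+
+```python
+import tensorflow as tf
+
+@tf.function(experimental_compile=True)  # compile the outermost scope
+def train_step(x):
+  with tf.GradientTape() as tape:
+    tape.watch(x)
+    ta = tf.TensorArray(tf.float32, size=3)
+    for i in tf.range(3):
+      ta = ta.write(i, x * tf.cast(i, tf.float32))
+    y = tf.reduce_sum(ta.stack())
+  # The derivative is taken inside the compiled scope.
+  return tape.gradient(y, x)
+
+print(train_step(tf.constant(2.0)))  # tf.Tensor(3.0, shape=(), dtype=float32)
+```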
+ +### Random Number Generation + +XLA currently ignores TF seeds to random operations. This affects stateful TF +random operations, such as `tf.random.normal`, or `tf.nn.dropout`. XLA will +behave as if the compilation was seeded with a new unique seed at each run. This +limitation does not apply to stateless random ops. + ## XLA Frontends Apart from TensorFlow, XLA programs can be generated by: diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index 495701eaac2..002d07184a7 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -2299,20 +2299,26 @@ The output is guaranteed to be a deterministic function of the initial state but it is *not* guaranteed to be deterministic between backends and different compiler versions. -`RngBitGenerator(algorithm, key, shape)` | Arguments | Type | Semantics | -|---------------- | ----------------- | ------------------------------------- | -| `algorithm` | `RandomAlgorithm` | PRNG algorithm to be used. | | -`initial_state` | `XlaOp` | Initial state for the PRNG algorithm. | | `shape` | -`Shape` | Output shape for generated data. | +`RngBitGenerator(algorithm, key, shape)` -Available values for `algorithm`: * `rng_default`: Backend specific algorithm -with backend specific shape requirements. * `rng_three_fry`: ThreeFry -counter-based PRNG algorithm. The `initial_state` shape is `u64[2]` with -arbitrary values. -[Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf) -* `rng_philox`: Philox algorithm to generate random numbers in parallel. The -`initial_state` shape is `u64[3]` with arbitrary values. -[Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf) +Arguments | Type | Semantics +--------------- | ----------------- | ------------------------------------- +`algorithm` | `RandomAlgorithm` | PRNG algorithm to be used. +`initial_state` | `XlaOp` | Initial state for the PRNG algorithm. +`shape` | `Shape` | Output shape for generated data. + +Available values for `algorithm`: + +- `rng_default`: Backend specific algorithm with backend specific shape + requirements. + +- `rng_three_fry`: ThreeFry counter-based PRNG algorithm. The `initial_state` + shape is `u64[2]` with arbitrary values. + [Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf) + +- `rng_philox`: Philox algorithm to generate random numbers in parallel. The + `initial_state` shape is `u64[3]` with arbitrary values. + [Salmon et al. SC 2011. 
Parallel random numbers: as easy as 1, 2, 3.](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf) ## Scatter diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index cbbad741ce3..73c37d6b2f3 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -2104,6 +2104,32 @@ MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr, root_piece_->set_subshape(shape_.get()); } +MutableBorrowingLiteral::MutableBorrowingLiteral(absl::Span src_buf_ptrs, + const Shape& shape) + : MutableLiteralBase() { + shape_ = absl::make_unique(shape); + if (!shape_->IsTuple()) { + CHECK_EQ(src_buf_ptrs.size(), 1); + root_piece_ = new Piece(); + root_piece_->set_buffer(const_cast(src_buf_ptrs[0])); + root_piece_->set_subshape(shape_.get()); + } else { + CHECK(!ShapeUtil::IsNestedTuple(*shape_)); + CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_)); + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + for (int i = 0; i < src_buf_ptrs.size(); ++i) { + Piece child_piece; + const auto& src_shape = shape_->tuple_shapes(i); + CHECK(src_shape.IsArray()); + child_piece.set_subshape(&src_shape); + child_piece.set_buffer(src_buf_ptrs[i]); + root_piece_->emplace_back(std::move(child_piece)); + } + } +} + MutableBorrowingLiteral::~MutableBorrowingLiteral() { if (root_piece_ != nullptr) { delete root_piece_; diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index 1553d042e80..a2be92fbf5b 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -776,6 +776,10 @@ class MutableBorrowingLiteral : public MutableLiteralBase { const ShapeIndex& view_root); MutableBorrowingLiteral(const char* src_buf_ptr, const Shape& shape); + // Create a literal from a list of buffers and a shape. + // Returns a tuple literal if `shape` is a tuple type. + MutableBorrowingLiteral(absl::Span src_buf_ptrs, const Shape& shape); + private: // Recursively copies the subtree from the `src_piece` at the given child // index to the `dest_piece`. 
For buffers only the pointers are copied, but diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 8c6bc84cf8e..10737489331 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -1,5 +1,6 @@ load("//tensorflow/core/platform:build_config.bzl", "pyx_library") load("//tensorflow/compiler/xla:xla.bzl", "xla_py_test_deps") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "pybind_extension") @@ -186,6 +187,89 @@ cc_library( ], ) +cc_library( + name = "ops", + srcs = ["ops.cc"], + hdrs = ["ops.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":types", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:comparators", + "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/client/lib:qr", + "//tensorflow/compiler/xla/client/lib:self_adjoint_eig", + "//tensorflow/compiler/xla/client/lib:sorting", + "//tensorflow/compiler/xla/client/lib:svd", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@pybind11", + ], +) + +cc_library( + name = "outfeed_receiver", + srcs = ["outfeed_receiver.cc"], + hdrs = ["outfeed_receiver.h"], + deps = [ + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/pjrt:pjrt_client", + "//tensorflow/core/profiler/lib:traceme", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", + ], +) + +tf_cc_test( + name = "cpu_outfeed_receiver_test", + size = "small", + srcs = ["outfeed_receiver_test.cc"], + deps = [ + ":outfeed_receiver", + "//tensorflow/compiler/jit:xla_cpu_jit", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:executable_build_options", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/pjrt:cpu_device", + "//tensorflow/compiler/xla/pjrt:pjrt_client", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "outfeed_receiver_py", + srcs = ["outfeed_receiver_py.cc"], + hdrs = ["outfeed_receiver_py.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":outfeed_receiver", + ":types", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/pjrt:pjrt_client", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/synchronization", + "@pybind11", + ], +) + config_setting( name = "enable_gpu", values = {"define": "xla_python_enable_gpu=true"}, @@ -205,7 +289,9 @@ pybind_extension( deps = [ ":bfloat16", ":dlpack", + ":ops", ":python_ref_manager", + ":outfeed_receiver_py", ":types", "@com_google_absl//absl/base", "@com_google_absl//absl/hash", @@ -228,12 +314,6 @@ pybind_extension( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/lib:comparators", - "//tensorflow/compiler/xla/client/lib:math", - 
"//tensorflow/compiler/xla/client/lib:qr", - "//tensorflow/compiler/xla/client/lib:self_adjoint_eig", - "//tensorflow/compiler/xla/client/lib:sorting", - "//tensorflow/compiler/xla/client/lib:svd", "//tensorflow/compiler/xla/pjrt:cpu_device", "//tensorflow/compiler/xla/pjrt:nvidia_gpu_device", "//tensorflow/compiler/xla/pjrt:pjrt_client", @@ -260,8 +340,8 @@ pybind_extension( "//tensorflow/core:lib_internal_impl", # buildcleaner: keep "//tensorflow/core/profiler/lib:profiler_backends", "//tensorflow/core/profiler/lib:profiler_session", - "//tensorflow/core/profiler/lib:traceme", "//tensorflow/core/profiler/rpc:profiler_server", + "//tensorflow/python/profiler/internal:traceme_wrapper", "//tensorflow/stream_executor:device_memory_allocator", "//tensorflow/stream_executor:platform", ] + select({ diff --git a/tensorflow/compiler/xla/python/ops.cc b/tensorflow/compiler/xla/python/ops.cc new file mode 100644 index 00000000000..89891d39f78 --- /dev/null +++ b/tensorflow/compiler/xla/python/ops.cc @@ -0,0 +1,356 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/python/ops.h" + +#include +#include + +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "pybind11/attr.h" +#include "pybind11/pybind11.h" +#include "tensorflow/compiler/xla/client/lib/comparators.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/qr.h" +#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" +#include "tensorflow/compiler/xla/client/lib/sorting.h" +#include "tensorflow/compiler/xla/client/lib/svd.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/python/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +namespace py = pybind11; + +void BuildOpsSubmodule(py::module* m) { + // ops submodule, containing free functions that add operators to an + // XlaBuilder. 
+ py::module ops = m->def_submodule("ops", "XLA operations"); + + py::enum_( + ops, "TriangularSolveOptions_Transpose") + .value("TRANSPOSE_INVALID", TriangularSolveOptions::TRANSPOSE_INVALID) + .value("NO_TRANSPOSE", TriangularSolveOptions::NO_TRANSPOSE) + .value("TRANSPOSE", TriangularSolveOptions::TRANSPOSE) + .value("ADJOINT", TriangularSolveOptions::ADJOINT); + + ops.def("AfterAll", &AfterAll, py::arg("builder"), py::arg("tokens")); + ops.def( + "AllReduce", + static_cast, + const absl::optional&, const absl::optional&)>( + &AllReduce), + py::arg("operand"), py::arg("computation"), + py::arg("replica_groups") = py::list(), + py::arg("channel_id") = absl::nullopt, + py::arg("shape_with_layout") = absl::nullopt); + ops.def("AllToAll", &AllToAll, py::arg("operand"), py::arg("split_dimension"), + py::arg("concat_dimension"), py::arg("split_count"), + py::arg("replica_groups") = py::list(), + py::arg("layout") = absl::nullopt); + ops.def("CollectivePermute", &CollectivePermute, py::arg("operand"), + py::arg("source_target_pairs")); + ops.def("CreateToken", &CreateToken, py::arg("builder")); + ops.def("CrossReplicaSum", + static_cast)>( + &CrossReplicaSum), + py::arg("operand"), py::arg("replica_groups") = py::list()); + ops.def("BitcastConvertType", &BitcastConvertType, py::arg("operand"), + py::arg("new_element_type")); + ops.def("Broadcast", &Broadcast, py::arg("operand"), py::arg("sizes")); + ops.def("BroadcastInDim", &BroadcastInDim, py::arg("operand"), + py::arg("shape"), py::arg("broadcast_dimensions")); + ops.def("Call", &Call, py::arg("builder"), py::arg("computation"), + py::arg("operands")); + ops.def("Cholesky", &Cholesky, py::arg("a"), py::arg("lower") = true); + ops.def("Clamp", &Clamp, py::arg("min"), py::arg("operand"), py::arg("max")); + ops.def("Collapse", &Collapse, py::arg("operand"), py::arg("dimensions")); + ops.def("ConcatInDim", &ConcatInDim, py::arg("builder"), py::arg("operands"), + py::arg("dimension")); + ops.def("Conditional", + static_cast, + absl::Span)>(&Conditional), + py::arg("branch_index"), py::arg("branch_computations"), + py::arg("branch_operands")); + ops.def("Conditional", + static_cast(&Conditional), + py::arg("predicate"), py::arg("true_operand"), + py::arg("true_computation"), py::arg("false_operand"), + py::arg("false_computation")); + ops.def("Constant", &ConstantLiteral, py::arg("builder"), py::arg("literal")); + ops.def("ConstantLiteral", &ConstantLiteral, py::arg("builder"), + py::arg("literal")); + ops.def("ConvGeneralDilated", &ConvGeneralDilated, py::arg("lhs"), + py::arg("rhs"), py::arg("window_strides"), py::arg("padding"), + py::arg("lhs_dilation"), py::arg("rhs_dilation"), + py::arg("dimension_numbers"), py::arg("feature_group_count") = 1, + py::arg("batch_group_count") = 1, + py::arg("precision_config") = nullptr); + ops.def("ConvertElementType", &ConvertElementType, py::arg("operand"), + py::arg("new_element_type")); + ops.def( + "CustomCall", + [](XlaBuilder* builder, const py::bytes& call_target_name, + absl::Span operands, const Shape& shape, + const py::bytes& opaque) -> XlaOp { + return CustomCall(builder, call_target_name, operands, shape, opaque); + }, + py::arg("builder"), py::arg("call_target_name"), py::arg("operands"), + py::arg("shape"), py::arg("opaque") = py::bytes("")); + ops.def( + "CustomCallWithLayout", + [](XlaBuilder* builder, const py::bytes& call_target_name, + absl::Span operands, const Shape& shape_with_layout, + absl::Span operand_shapes_with_layout, + const py::bytes& opaque) -> XlaOp { + return 
CustomCallWithLayout(builder, call_target_name, operands, + shape_with_layout, + operand_shapes_with_layout, opaque); + }, + py::arg("builder"), py::arg("call_target_name"), py::arg("operands"), + py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"), + py::arg("opaque") = py::bytes("")); + ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"), + py::arg("precision_config") = nullptr); + ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"), + py::arg("dimension_numbers"), py::arg("precision_config") = nullptr); + ops.def("DynamicSlice", + static_cast, + absl::Span)>(&DynamicSlice), + py::arg("operand"), py::arg("start_indices"), py::arg("slice_sizes")); + ops.def("DynamicUpdateSlice", + static_cast)>( + &DynamicUpdateSlice), + py::arg("operand"), py::arg("update"), py::arg("start_indices")); + + ops.def("Fft", &Fft, py::arg("operand"), py::arg("fft_type"), + py::arg("fft_length")); + + ops.def("Gather", &Gather, py::arg("a"), py::arg("start_indices"), + py::arg("dimension_numbers"), py::arg("slice_sizes"), + py::arg("indices_are_sorted") = false); + ops.def("GetTupleElement", &GetTupleElement, py::arg("tuple_data"), + py::arg("index")); + ops.def("InfeedWithToken", &InfeedWithToken, py::arg("token"), + py::arg("shape"), py::arg("config") = ""); + ops.def("Iota", + static_cast(&Iota), + py::arg("builder"), py::arg("shape"), py::arg("iota_dimension")); + ops.def("Iota", + static_cast(&Iota), + py::arg("builder"), py::arg("type"), py::arg("size")); + ops.def("Map", &Map, py::arg("builder"), py::arg("operands"), + py::arg("computation"), py::arg("dimensions"), + py::arg("static_operands") = py::list()); + ops.def("NextAfter", &NextAfter, py::arg("from"), py::arg("to")); + ops.def("OutfeedWithToken", &OutfeedWithToken, py::arg("operand"), + py::arg("token"), py::arg("shape_with_layout"), + py::arg("outfeed_config") = ""); + ops.def("Pad", &Pad, py::arg("operand"), py::arg("padding_value"), + py::arg("padding_config")); + ops.def("Parameter", + static_cast&)>( + &Parameter), + py::arg("builder"), py::arg("parameter_number"), py::arg("shape"), + py::arg("name") = "", + py::arg("replicated_at_leaf_buffers") = std::vector()); + ops.def( + "QR", + [](XlaOp a, bool full_matrices) -> StatusOr> { + TF_ASSIGN_OR_RETURN(auto qr, QRDecomposition(a, full_matrices)); + return std::make_pair(qr.q, qr.r); + }, + py::arg("operand"), py::arg("full_matrices")); + ops.def( + "Eigh", + [](XlaOp a, bool lower, int64 max_iter, + float epsilon) -> std::pair { + auto eigh = SelfAdjointEig(a, lower, max_iter, epsilon); + return std::make_pair(eigh.v, eigh.w); + }, + py::arg("a"), py::arg("lower") = true, py::arg("max_iter") = 100, + py::arg("epsilon") = 1e-6); + ops.def( + "SVD", + [](XlaOp a, int64 max_iter, + float epsilon) -> std::tuple { + auto svd = SVD(a, max_iter, epsilon); + return std::make_tuple(svd.u, svd.d, svd.v); + }, + py::arg("a"), py::arg("max_iter") = 100, py::arg("epsilon") = 1e-6); + ops.def("Reduce", + static_cast, + absl::Span, const XlaComputation&, + absl::Span)>(&Reduce), + py::arg("builder"), py::arg("operands"), py::arg("init_values"), + py::arg("computation"), py::arg("dimensions_to_reduce")); + ops.def("ReducePrecision", &ReducePrecision, py::arg("operand"), + py::arg("exponent_bits"), py::arg("mantissa_bits")); + ops.def("ReduceWindowWithGeneralPadding", &ReduceWindowWithGeneralPadding, + py::arg("operand"), py::arg("init_value"), py::arg("computation"), + py::arg("window_dimensions"), py::arg("window_strides"), + py::arg("base_dilations"), 
py::arg("window_dilations"), + py::arg("padding")); + ops.def("ReplicaId", &ReplicaId, py::arg("builder")); + ops.def("Reshape", + static_cast, + absl::Span)>(&Reshape), + py::arg("operand"), py::arg("dimensions"), py::arg("new_sizes")); + ops.def("Reshape", + static_cast)>(&Reshape), + py::arg("operand"), py::arg("new_sizes")); + ops.def("Rev", &Rev, py::arg("operand"), py::arg("dimensions")); + ops.def("RngNormal", &RngNormal, py::arg("mu"), py::arg("sigma"), + py::arg("shape")); + ops.def("RngUniform", &RngUniform, py::arg("a"), py::arg("b"), + py::arg("shape")); + ops.def("Scatter", &Scatter, py::arg("input"), py::arg("scatter_indices"), + py::arg("updates"), py::arg("update_computation"), + py::arg("dimension_numbers"), py::arg("indices_are_sorted") = false, + py::arg("unique_indices") = false); + ops.def("Select", &Select, py::arg("pred"), py::arg("on_true"), + py::arg("on_false")); + ops.def("SelectAndScatterWithGeneralPadding", + &SelectAndScatterWithGeneralPadding, py::arg("operand"), + py::arg("select"), py::arg("window_dimensions"), + py::arg("window_strides"), py::arg("padding"), py::arg("source"), + py::arg("init_value"), py::arg("scatter")); + ops.def("Slice", &Slice, py::arg("operand"), py::arg("start_indices"), + py::arg("limit_indices"), py::arg("strides")); + ops.def("SliceInDim", &SliceInDim, py::arg("operand"), py::arg("start_index"), + py::arg("limit_index"), py::arg("stride"), py::arg("dimno")); + ops.def( + "Sort", + [](XlaBuilder* builder, absl::Span operands, + absl::optional comparator, int64 dimension, + bool is_stable) -> XlaOp { + return builder->ReportErrorOrReturn([&]() -> StatusOr { + std::vector operand_types; + for (const auto& operand : operands) { + TF_ASSIGN_OR_RETURN(auto operand_shape, builder->GetShape(operand)); + operand_types.push_back(operand_shape.element_type()); + } + + if (comparator) { + return Sort(operands, **comparator, dimension, is_stable); + } else { + return Sort(operands, + CreateScalarLtComputation(operand_types, builder), + dimension, is_stable); + } + }); + }, + py::arg("builder"), py::arg("operands"), + py::arg("comparator") = absl::nullopt, py::arg("dimension") = -1, + py::arg("is_stable") = false); + ops.def("TopK", &TopK, py::arg("input"), py::arg("k")); + ops.def("Transpose", &Transpose, py::arg("operand"), py::arg("permutation")); + ops.def("TriangularSolve", &TriangularSolve, py::arg("a"), py::arg("b"), + py::arg("left_side"), py::arg("lower"), py::arg("unit_diagonal"), + py::arg("transpose_a")); + ops.def("Tuple", &Tuple, py::arg("builder"), py::arg("elements")); + ops.def("While", &While, py::arg("condition"), py::arg("body"), + py::arg("init")); + + ops.def("Igamma", &Igamma, py::arg("a"), py::arg("x")); + ops.def("Igammac", &Igammac, py::arg("a"), py::arg("x")); + ops.def("IgammaGradA", &IgammaGradA, py::arg("a"), py::arg("x")); + ops.def("RandomGammaGrad", &RandomGammaGrad, py::arg("a"), py::arg("x")); + ops.def("RegularizedIncompleteBeta", &RegularizedIncompleteBeta, py::arg("a"), + py::arg("b"), py::arg("x")); + +#define BINARY_OP(op) \ + ops.def( \ + #op, \ + [](XlaOp a, XlaOp b, absl::optional> dims) { \ + return dims ? 
op(a, b, *dims) : op(a, b); \ + }, \ + py::arg("lhs"), py::arg("rhs"), \ + py::arg("broadcast_dimensions") = absl::nullopt) + BINARY_OP(Eq); + BINARY_OP(Ne); + BINARY_OP(Ge); + BINARY_OP(Gt); + BINARY_OP(Lt); + BINARY_OP(Le); + BINARY_OP(Add); + BINARY_OP(Sub); + BINARY_OP(Mul); + BINARY_OP(Div); + BINARY_OP(Rem); + BINARY_OP(Max); + BINARY_OP(Min); + BINARY_OP(And); + BINARY_OP(Or); + BINARY_OP(Xor); + BINARY_OP(ShiftLeft); + BINARY_OP(ShiftRightArithmetic); + BINARY_OP(ShiftRightLogical); + BINARY_OP(Atan2); + BINARY_OP(Pow); + BINARY_OP(Complex); +#undef BINARY_OP + +#define UNARY_OP(op) ops.def(#op, &op) + UNARY_OP(Not); + UNARY_OP(PopulationCount); + UNARY_OP(Clz); + UNARY_OP(Abs); + UNARY_OP(Exp); + UNARY_OP(Expm1); + UNARY_OP(Floor); + UNARY_OP(Ceil); + UNARY_OP(Round); + UNARY_OP(Log); + UNARY_OP(Log1p); + UNARY_OP(Sign); + UNARY_OP(Cos); + UNARY_OP(Sin); + UNARY_OP(Tanh); + UNARY_OP(IsFinite); + UNARY_OP(Neg); + UNARY_OP(Sqrt); + UNARY_OP(Rsqrt); + UNARY_OP(Square); + UNARY_OP(Reciprocal); + UNARY_OP(Erfc); + UNARY_OP(Erf); + UNARY_OP(ErfInv); + UNARY_OP(Lgamma); + UNARY_OP(Digamma); + UNARY_OP(BesselI0e); + UNARY_OP(BesselI1e); + UNARY_OP(Acos); + UNARY_OP(Asin); + UNARY_OP(Atan); + UNARY_OP(Tan); + UNARY_OP(Acosh); + UNARY_OP(Asinh); + UNARY_OP(Atanh); + UNARY_OP(Cosh); + UNARY_OP(Sinh); + UNARY_OP(Real); + UNARY_OP(Imag); + UNARY_OP(Conj); +#undef UNARY_OP +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/python/ops.h b/tensorflow/compiler/xla/python/ops.h new file mode 100644 index 00000000000..7fe34e941ba --- /dev/null +++ b/tensorflow/compiler/xla/python/ops.h @@ -0,0 +1,27 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_OPS_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_OPS_H_ + +#include "pybind11/pybind11.h" + +namespace xla { + +void BuildOpsSubmodule(pybind11::module* m); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_OPS_H_ diff --git a/tensorflow/compiler/xla/python/outfeed_receiver.cc b/tensorflow/compiler/xla/python/outfeed_receiver.cc new file mode 100644 index 00000000000..0be4167c397 --- /dev/null +++ b/tensorflow/compiler/xla/python/outfeed_receiver.cc @@ -0,0 +1,492 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/python/outfeed_receiver.h"
+
+#include
+
+#include
+#include
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/profiler/lib/traceme.h"
+
+// Implementation notes:
+//
+// Startup:
+// -------
+//
+// The startup is initiated by a call from Python to StartOutfeedReceiver,
+// which starts N threads for listening to the N devices and for enqueueing
+// the received data into a callback queue. There is one additional callback
+// thread for dequeuing the data and invoking the Python callback.
+//
+// Framing protocol
+// ----------------
+//
+// The outfeed mechanism has a single channel and the receiver must know
+// exactly the shape and number of outfeed operations issued by the compiled
+// code. This makes it hard to use outfeed in conditionals and loops and
+// especially when outfeeding different-shaped data.
+//
+// To address this, when we compile the code we capture the shape of the
+// data being outfed, and we generate a consumer ID (uint32_t) that is unique
+// across the lifetime of the program to: the Python callable to call back to,
+// the shape of the arguments, the keyword arguments to pass to the callable.
+// Each outfeed payload is preceded by a header (of shape u32[2]) with a
+// special first value and the consumer ID. We maintain a registry of shapes
+// by consumer ID. When receiving, we look up the shape by consumer ID, and
+// then we read the payload.
+//
+// Back pressure:
+// --------------
+//
+// We maintain a sum of the bytes from all the data waiting in the callback
+// queue. The listening threads will wait for the sum to drop below a
+// configurable threshold, default 256 MB. While a listening thread is waiting,
+// on CPU and GPU the next outfeed operation from the device will block. On
+// TPU there is a buffer, but eventually the TPU will also block.
+//
+// Shutdown:
+// ---------
+//
+// The shutdown is initiated automatically when the last reference to the
+// outfeed receiver object is dropped, and the Python garbage collector invokes
+// the destructor.
+//
+// The shutdown sequence is implemented as follows:
+// * we enqueue on all devices a computation that outfeeds a special header
+//   with consumer ID kOutfeedCidShutdown.
+// * when each listening thread gets the shutdown header, it decrements
+//   a counter of listening threads, and if the counter reaches 0, it
+//   enqueues a special shutdown callback.
+// * when the callback thread gets the shutdown callback marker, it terminates.
+// * the shutdown code waits until all threads terminate.
+//
+// Since we currently keep the shape registry in the OutfeedReceiver, it is
+// not safe to replace the OutfeedReceiver instance during the lifetime of
+// the JAX program, or else previously cached jitted computations may refer
+// to previously cached shapes. This can be solved, but for now we disallow
+// replacing the OutfeedReceiver, and do not provide a Shutdown API to the
+// Python program.
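To make the framing protocol above concrete, here is a small, self-contained C++ sketch of the decoding loop it implies. It is illustrative only and not part of the patch: the device outfeed and the Shape registry are replaced by an in-memory word stream and a payload-length map, and all names outside kOutfeedHeaderStart/kOutfeedCidShutdown (which it mirrors) are hypothetical.

// Toy decoder for the header-plus-payload framing described above.
#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

namespace outfeed_sketch {
constexpr uint32_t kHeaderStart = 271828;  // mirrors kOutfeedHeaderStart
constexpr uint32_t kCidShutdown = 0;       // mirrors kOutfeedCidShutdown

// Stand-in for the shape registry: consumer ID -> payload length in words.
using Registry = std::unordered_map<uint32_t, size_t>;

// Consumes frames from `stream` until the shutdown header is seen.
void DecodeStream(const std::vector<uint32_t>& stream, const Registry& registry) {
  size_t pos = 0;
  while (pos + 2 <= stream.size()) {
    uint32_t start = stream[pos++];        // header word 0: framing marker
    uint32_t consumer_id = stream[pos++];  // header word 1: consumer ID
    if (start != kHeaderStart) {
      std::cerr << "Corrupt frame at word " << pos - 2 << "\n";
      return;
    }
    if (consumer_id == kCidShutdown) return;           // shutdown frame has no payload
    size_t payload_words = registry.at(consumer_id);   // look up the registered "shape"
    std::vector<uint32_t> payload(stream.begin() + pos,
                                  stream.begin() + pos + payload_words);
    pos += payload_words;
    std::cout << "consumer " << consumer_id << ": " << payload.size()
              << " words\n";
  }
}
}  // namespace outfeed_sketch

int main() {
  // One 3-word payload for consumer 5, followed by a shutdown frame.
  outfeed_sketch::DecodeStream({271828, 5, 10, 20, 30, 271828, 0}, {{5, 3}});
  return 0;
}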
+
+namespace xla {
+
+// The header contains:
+// 0. kOutfeedHeaderStart
+// 1. consumer id
+int constexpr kOutfeedHeaderWords = 2;
+uint32_t constexpr kOutfeedHeaderStart = 271828;
+// Special consumer IDs, without outfeed payload.
+uint32_t constexpr kOutfeedCidShutdown = 0;
+
+// A Device and its PjRtClient.
+struct DeviceWithClient {
+  Device* device;
+  std::shared_ptr client;
+};
+
+// Encapsulates data received from a device outfeed.
+class OutfeedData {
+ public:
+  OutfeedData(DeviceWithClient device_client, uint32_t consumer_id, Shape shape)
+      : device_client_(device_client),
+        consumer_id_(consumer_id),
+        shape_(shape),
+        literal_(nullptr),
+        literal_size_bytes_(0) {}
+
+  DeviceWithClient device_client() { return device_client_; }
+  uint32_t consumer_id() const { return consumer_id_; }
+  Shape shape() const { return shape_; }
+  std::unique_ptr literal() {
+    CHECK(literal_);
+    return std::move(literal_);
+  }
+
+  void SetLiteral(std::unique_ptr literal);
+
+  ssize_t literal_size_bytes() const { return literal_size_bytes_; }
+
+  std::string DebugString() const;
+
+ private:
+  DeviceWithClient device_client_;
+  uint32_t consumer_id_;
+  Shape shape_;
+  std::unique_ptr literal_;
+  ssize_t literal_size_bytes_;
+};
+
+void OutfeedData::SetLiteral(std::unique_ptr literal) {
+  literal_ = std::move(literal);
+  shape_ = literal_->shape();
+  int total_size_bytes = 0;
+  ShapeUtil::ForEachSubshape(
+      shape_, [&](const Shape& literal_subshape, const ShapeIndex& index) {
+        if (!literal_subshape.IsTuple()) {
+          total_size_bytes += ShapeUtil::ByteSizeOf(literal_subshape, 8);
+        }
+      });
+  literal_size_bytes_ = total_size_bytes;
+}
+
+std::string OutfeedData::DebugString() const {
+  return absl::StrFormat("dev=%s; cons=%d; shape=%s",
+                         device_client_.device->DebugString(), consumer_id_,
+                         shape_.ToString());
+}
+
+class OutfeedReceiverImpl {
+ public:
+  OutfeedReceiverImpl(OutfeedReceiver::Callback callback,
+                      std::vector> clients,
+                      ssize_t max_callback_queue_size_bytes);
+
+  OutfeedReceiverImpl(const OutfeedReceiverImpl&) = delete;
+  OutfeedReceiverImpl& operator=(const OutfeedReceiverImpl&) = delete;
+
+  // Blocks until all data has been received from devices and all data
+  // in the queue has been passed to Python.
+  ~OutfeedReceiverImpl();
+
+  void Start();
+
+  StatusOr AddOutfeedToBuilder(XlaBuilder* builder, XlaOp token,
+                               uint32_t consumer_id,
+                               std::vector arrays);
+
+ private:
+  bool CallbackQueueNotEmpty() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+    return !callback_queue_.empty();
+  }
+
+  bool CallbackQueueHasSpace() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+    return callback_queue_size_bytes_ < max_callback_queue_size_bytes_;
+  }
+
+  bool ShutdownDone() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+    return (num_working_callback_threads_ == 0 && num_listening_threads_ == 0);
+  }
+
+  void CallbackThreadLoop();
+  void DeviceListenerThreadLoop(int device_idx);
+
+  // Enqueues to a device an outfeed operation with a shutdown consumer ID.
+  Status SendShutdownOutfeedHeader(int device_idx);
+
+  // Receives a raw Literal from a device outfeed.
+  StatusOr> ReceiveRawFromOutfeed(const Device* device,
+                                  const Shape& shape);
+
+  // Enqueues received data in the callback queue.
+  void EnqueueReceivedData(std::unique_ptr received)
+      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Shuts down the threads. See implementation notes at top of file.
+  // It is not safe to restart an OutfeedReceiver after shutting down one.
+  void Shutdown();
+
+  OutfeedReceiver::Callback callback_;
+  // The devices on which we are listening, with their clients.
+ std::vector devices_; + // Maximum bytes capacity of the callback queue. + uint64_t max_callback_queue_size_bytes_; + + absl::Mutex mu_; + // Registered shapes by consumer id. + // The shape registry must be alive as long as the program exists. + // Right now we tell the user to never restart after Shutdown. + absl::flat_hash_map shape_registry_ TF_GUARDED_BY(mu_); + // How many bytes of Literal are in the callback queue. + uint64_t callback_queue_size_bytes_ TF_GUARDED_BY(mu_); + // Threads listening. + int num_listening_threads_ TF_GUARDED_BY(mu_); + bool shutdown_started_ TF_GUARDED_BY(mu_); + + // How many callback threads are still working. Used for shutdown. + int num_working_callback_threads_ TF_GUARDED_BY(mu_); + + std::queue> callback_queue_ TF_GUARDED_BY(mu_); + // The threadpool must come last to ensure the queue exists + // when the pool destructor is called. + std::unique_ptr threads_; +}; + +OutfeedReceiverImpl::OutfeedReceiverImpl( + OutfeedReceiver::Callback callback, + std::vector> clients, + ssize_t max_callback_queue_size_bytes) { + callback_ = callback; + max_callback_queue_size_bytes_ = max_callback_queue_size_bytes; + for (const auto& client : clients) { + for (const auto& device : client->devices()) { + devices_.push_back(DeviceWithClient{device.get(), client}); + } + } + CHECK_GT(devices_.size(), 0); + + callback_queue_size_bytes_ = 0; + num_listening_threads_ = 0; + num_working_callback_threads_ = 0; + shutdown_started_ = false; +} + +void OutfeedReceiverImpl::Start() { + { + absl::MutexLock lock(&mu_); + CHECK(!shutdown_started_); + } + int num_threads = 1 + devices_.size(); + threads_ = absl::make_unique( + tensorflow::Env::Default(), "outfeed_receiver", num_threads); + threads_->Schedule([this]() { CallbackThreadLoop(); }); + for (int device_idx = 0; device_idx < devices_.size(); ++device_idx) { + threads_->Schedule( + [this, device_idx]() { DeviceListenerThreadLoop(device_idx); }); + } +} + +void OutfeedReceiverImpl::Shutdown() { + VLOG(2) << "Shutdown start"; + { + absl::MutexLock lock(&mu_); + CHECK(!shutdown_started_); + shutdown_started_ = true; + } + for (int device_idx = 0; device_idx < devices_.size(); ++device_idx) { + CHECK(SendShutdownOutfeedHeader(device_idx).ok()); + } + VLOG(2) << "Shutdown waiting for listening and callback threads to stop"; + absl::MutexLock lock(&mu_); + mu_.Await(absl::Condition(this, &OutfeedReceiverImpl::ShutdownDone)); + VLOG(2) << "Shutdown done"; +} + +OutfeedReceiverImpl::~OutfeedReceiverImpl() { + VLOG(2) << "~OutfeedReceiverImpl"; + Shutdown(); +} + +void OutfeedReceiverImpl::DeviceListenerThreadLoop(int device_idx) { + { + absl::MutexLock lock(&mu_); + ++num_listening_threads_; + } + DeviceWithClient device_client = devices_[device_idx]; + while (true) { + Shape header_shape = ShapeUtil::MakeShape(U32, {kOutfeedHeaderWords}); + std::unique_ptr header = + ReceiveRawFromOutfeed(device_client.device, header_shape).ValueOrDie(); + absl::Span header_data = header->data(); + CHECK_EQ(header_data.size(), kOutfeedHeaderWords); + CHECK_EQ(header_data[0], kOutfeedHeaderStart); + uint32_t consumer_id = header_data[1]; + Shape shape; + { + absl::MutexLock lock(&mu_); + auto registered_shape = shape_registry_.find(consumer_id); + if (registered_shape == shape_registry_.end()) { + LOG(FATAL) + << "[" << device_client.device->DebugString() + << "] Cannot find registered shape for consumer ID " << consumer_id + << ". 
Perhaps the code was compiled with a different instance " + << "of OutfeedReceiver."; + } + shape = registered_shape->second; + } + auto received = + absl::make_unique(device_client, consumer_id, shape); + VLOG(2) << "Listener received header " << received->DebugString(); + if (consumer_id == kOutfeedCidShutdown) { + VLOG(2) << "[" << device_client.device->DebugString() + << "] Listener received shutdown header"; + absl::MutexLock lock(&mu_); + --num_listening_threads_; + if (num_listening_threads_ == 0) { + VLOG(2) << "Last listener shutdown; enqueue shutdown callback"; + EnqueueReceivedData(std::move(received)); + } + return; + } + std::unique_ptr data = + ReceiveRawFromOutfeed(device_client.device, shape).ValueOrDie(); + received->SetLiteral(std::move(data)); + absl::MutexLock lock(&mu_); + EnqueueReceivedData(std::move(received)); + } +} + +void OutfeedReceiverImpl::EnqueueReceivedData( + std::unique_ptr received) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + mu_.Await(absl::Condition(this, &OutfeedReceiverImpl::CallbackQueueHasSpace)); + ssize_t literal_size_bytes = received->literal_size_bytes(); + callback_queue_size_bytes_ += literal_size_bytes; + VLOG(2) << "Listener enqueues data " << received->DebugString() << " of size " + << literal_size_bytes << " bytes; " << (1 + callback_queue_.size()) + << " callbacks in queue of total size " << callback_queue_size_bytes_ + << " bytes.\n"; + callback_queue_.push(std::move(received)); +} + +StatusOr> OutfeedReceiverImpl::ReceiveRawFromOutfeed( + const Device* device, const Shape& shape) { + std::shared_ptr literal_shared; + + TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, + device->GetLocalDeviceState()); + TF_ASSIGN_OR_RETURN(Literal literal, + local_device->client()->TransferFromOutfeedLocal( + shape, local_device->device_ordinal())); + + return absl::make_unique(std::move(literal)); +} + +void OutfeedReceiverImpl::CallbackThreadLoop() { + { + absl::MutexLock lock(&mu_); + num_working_callback_threads_++; + CHECK_EQ(num_working_callback_threads_, 1); + } + while (true) { + std::unique_ptr received; + { + absl::MutexLock lock(&mu_); + mu_.Await( + absl::Condition(this, &OutfeedReceiverImpl::CallbackQueueNotEmpty)); + received = std::move(callback_queue_.front()); + callback_queue_.pop(); + callback_queue_size_bytes_ -= received->literal_size_bytes(); + VLOG(2) << "Dequeued callback for " << received->DebugString() << "; " + << callback_queue_.size() << " callbacks in queue of total size " + << callback_queue_size_bytes_ << " bytes.\n"; + } + if (received->consumer_id() == kOutfeedCidShutdown) { + VLOG(2) << "Callback loop received shutdown signal"; + { + absl::MutexLock lock(&mu_); + CHECK(callback_queue_.empty()); + CHECK_EQ(callback_queue_size_bytes_, 0); + --num_working_callback_threads_; + } + VLOG(2) << "Callback loop done"; + return; + } + { + tensorflow::profiler::TraceMe traceme("OutfeedReceiver::Callback"); + DeviceWithClient device_client = received->device_client(); + callback_(device_client.device, std::move(device_client.client), + received->consumer_id(), received->literal()); + } + } +} + +Status OutfeedReceiverImpl::SendShutdownOutfeedHeader(int device_idx) { + const Device* device = devices_[device_idx].device; + constexpr int consumer_id = kOutfeedCidShutdown; + VLOG(2) << "[" << device->DebugString() + << "] SendSpecialHeader cons=" << consumer_id; + XlaBuilder builder( + absl::StrFormat("special_outfeed_header_%d_%d", consumer_id, device_idx)); + XlaOp send = + AddOutfeedToBuilder(&builder, CreateToken(&builder), 
consumer_id, {}) + .ValueOrDie(); + XlaComputation computation = builder.Build(send).ValueOrDie(); + + CompileOptions compile_options; + compile_options.executable_build_options.set_num_replicas(1); + compile_options.executable_build_options.set_num_partitions(1); + DeviceAssignment device_assignment(1, 1); + device_assignment(0, 0) = device->id(); + compile_options.executable_build_options.set_device_assignment( + device_assignment); + + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, + PjRtExecutable::Compile(computation, devices_[device_idx].client.get(), + std::move(compile_options))); + ExecuteOptions execute_options; + TF_ASSIGN_OR_RETURN(std::vector> output_buffers, + executable->Execute({}, execute_options)); + return Status::OK(); +} + +StatusOr OutfeedReceiverImpl::AddOutfeedToBuilder( + XlaBuilder* builder, XlaOp token, uint32_t consumer_id, + std::vector arrays) { + XlaOp data = Tuple(builder, std::move(arrays)); + Shape shape_with_layout = builder->GetShape(data).ValueOrDie(); + ShapeUtil::ForEachMutableSubshape( + &shape_with_layout, [](Shape* subshape, const ShapeIndex&) { + if (!subshape->has_layout()) { + LayoutUtil::SetToDefaultLayout(subshape); + } + }); + VLOG(2) << "RegisterShape cons=" << consumer_id + << "; shape=" << shape_with_layout.ToString(); + { + absl::MutexLock lock(&mu_); + auto found = shape_registry_.find(consumer_id); + if (found != shape_registry_.end()) { + if (!ShapeUtil::Equal(shape_with_layout, found->second)) { + return InvalidArgument( + "Shape %s does not match previous shape %s used " + "for consumer id %d", + shape_with_layout.DebugString(), found->second.DebugString(), + consumer_id); + } + } else { + shape_registry_.insert({consumer_id, shape_with_layout}); + } + } + + std::vector header{kOutfeedHeaderStart, consumer_id}; + XlaOp header_op = ConstantR1(builder, header); + token = OutfeedWithToken( + header_op, token, ShapeUtil::MakeShape(U32, {kOutfeedHeaderWords}), ""); + if (consumer_id != kOutfeedCidShutdown) { + token = OutfeedWithToken(data, token, shape_with_layout, ""); + } + return token; +} + +OutfeedReceiver::OutfeedReceiver( + Callback callback, std::vector> clients, + ssize_t max_callback_queue_size_bytes) { + p_impl_ = absl::make_unique( + callback, std::move(clients), max_callback_queue_size_bytes); +} + +OutfeedReceiver::~OutfeedReceiver() {} + +void OutfeedReceiver::Start() { p_impl_->Start(); } + +StatusOr OutfeedReceiver::AddOutfeedToBuilder( + XlaBuilder* builder, XlaOp token, uint32_t consumer_id, + std::vector arrays) { + if (consumer_id == kOutfeedCidShutdown) { + return InvalidArgument("Consumer ID cannot be a reserved value: %d", + consumer_id); + } + return p_impl_->AddOutfeedToBuilder(builder, token, consumer_id, arrays); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/python/outfeed_receiver.h b/tensorflow/compiler/xla/python/outfeed_receiver.h new file mode 100644 index 00000000000..a0fdfcd36f0 --- /dev/null +++ b/tensorflow/compiler/xla/python/outfeed_receiver.h @@ -0,0 +1,77 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_OUTFEED_RECEIVER_H_
+#define TENSORFLOW_COMPILER_XLA_PYTHON_OUTFEED_RECEIVER_H_
+
+#include
+
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
+#include "tensorflow/compiler/xla/shape.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+
+class OutfeedReceiverImpl;
+
+// Implements a multithreaded receiver of outfeeds from devices.
+class OutfeedReceiver {
+ public:
+  // A callback takes: the device, the client (for that device), the consumer
+  // id, and the received literal.
+  // The client pointer should be alive while the device is used.
+  using Callback = std::function,
+                                 uint32_t, std::shared_ptr)>;
+
+  // Constructs the receiver for the given clients and callback function.
+  //
+  // Args:
+  //   callback: a function to be called when an outfeed is ready for
+  //     processing.
+  //   clients: the clients whose devices to listen on.
+  //   max_callback_queue_size_bytes: the maximum number of bytes for all
+  //     received outfeeds queued to be processed. When this limit is reached
+  //     we pause receiving outfeeds from devices.
+  OutfeedReceiver(Callback callback,
+                  std::vector> clients,
+                  ssize_t max_callback_queue_size_bytes);
+
+  OutfeedReceiver(const OutfeedReceiver&) = delete;
+  OutfeedReceiver& operator=(const OutfeedReceiver&) = delete;
+
+  // Blocks until all data has been received from devices and all data
+  // in the queue has been passed to Python.
+  ~OutfeedReceiver();
+
+  // Starts the listener threads and the callback thread.
+  void Start();
+
+  // Adds to the computation builder the outfeed of the arrays.
+  // Has the side-effect of registering the sent shape for the consumer_id.
+  // Returns an error status if the outfeed shape is different from the
+  // previously used shape for the same consumer_id, or if the consumer id is
+  // invalid.
+  StatusOr AddOutfeedToBuilder(XlaBuilder* builder, XlaOp token,
+                               uint32_t consumer_id,
+                               std::vector arrays);
+
+ private:
+  std::unique_ptr p_impl_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_PYTHON_OUTFEED_RECEIVER_H_
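For reference, a minimal usage sketch of this interface follows. It is a condensed, illustrative version of the ReceiveOutfeedSimple unit test added later in this patch (outfeed_receiver_test.cc) and assumes the same helpers used there (GetCpuClient, Iota, CreateToken, and a compile-and-execute step); it is not itself part of the patch, and the consumer ID and queue size are arbitrary.

// Sketch: wire one outfeed into a computation and receive it via the callback.
auto callback = [](Device* device, std::shared_ptr<PjRtClient> client,
                   uint32_t consumer_id, std::shared_ptr<Literal> literal) {
  // Runs on the single callback thread; consume `literal` for `consumer_id`.
};
std::shared_ptr<PjRtClient> cpu_client = GetCpuClient(true).ValueOrDie();
auto receiver = std::make_shared<OutfeedReceiver>(
    callback, std::vector<std::shared_ptr<PjRtClient>>{cpu_client},
    /*max_callback_queue_size_bytes=*/128 * 1024 * 1024);
receiver->Start();

XlaBuilder builder("outfeed_example");
XlaOp data = Iota(&builder, ShapeUtil::MakeShape(U32, {16}), 0);
XlaOp token = receiver
                  ->AddOutfeedToBuilder(&builder, CreateToken(&builder),
                                        /*consumer_id=*/5, {data})
                  .ValueOrDie();
// Compile and execute the computation rooted at `token` (the test's
// CompileAndExecute helper shows one way); the callback fires once the device
// outfeeds the data. Dropping the last reference to `receiver` shuts the
// listener threads down and blocks until all queued data has been delivered.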
diff --git a/tensorflow/compiler/xla/python/outfeed_receiver_py.cc b/tensorflow/compiler/xla/python/outfeed_receiver_py.cc new file mode 100644 index 00000000000..a6256cfe86c --- /dev/null +++ b/tensorflow/compiler/xla/python/outfeed_receiver_py.cc @@ -0,0 +1,156 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/python/outfeed_receiver_py.h"
+
+#include
+
+#include "absl/memory/memory.h"
+#include "absl/synchronization/mutex.h"
+#include "pybind11/functional.h"
+#include "pybind11/pybind11.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
+#include "tensorflow/compiler/xla/python/outfeed_receiver.h"
+#include "tensorflow/compiler/xla/python/types.h"
+
+namespace xla {
+
+namespace py = pybind11;
+
+namespace {
+
+// A wrapper for OutfeedReceiver for use from Python, useful for ensuring
+// that the GIL is released before destroying the OutfeedReceiver.
+class OutfeedReceiverForPython {
+ public:
+  // A callback to Python takes: the device (wrapped with its client), the
+  // consumer id, and the received literal.
+  using CallbackToPython =
+      std::function, uint32_t, pybind11::object)>;
+
+  OutfeedReceiverForPython(CallbackToPython callback_python,
+                           std::vector> clients,
+                           ssize_t max_callback_queue_size_bytes) {
+    callback_python_ = callback_python;
+    outfeed_receiver_shutting_down_ = false;
+    OutfeedReceiver::Callback callback =
+        [this](Device* device, std::shared_ptr client,
+               uint32_t consumer_id, std::shared_ptr literal) {
+          this->Callback(device, client, consumer_id, literal);
+        };
+    outfeed_receiver_ = absl::make_unique(
+        callback, std::move(clients), max_callback_queue_size_bytes);
+  }
+  OutfeedReceiverForPython(const OutfeedReceiverForPython&) = delete;
+  OutfeedReceiverForPython& operator=(const OutfeedReceiverForPython&) = delete;
+
+  ~OutfeedReceiverForPython() {
+    // This destructor is called from the Python GC. Release the GIL for the
+    // duration of the destruction, including the destruction of the
+    // OutfeedReceiver, when we may actually have to wait for threads to end.
+    // During this time we do not call back into Python (sometimes we get an
+    // exception "std::runtime_error: scoped_acquire::dec_ref(): thread state
+    // must be current!").
+    {
+      absl::MutexLock lock(&mu_);
+      outfeed_receiver_shutting_down_ = true;
+    }
+    py::gil_scoped_release gil_release;
+    outfeed_receiver_ = nullptr; // Shut down the outfeed receiver.
+  }
+
+  void Start() { outfeed_receiver_->Start(); }
+
+  StatusOr AddOutfeed(XlaBuilder* builder, XlaOp token,
+                      uint32_t consumer_id, std::vector arrays) {
+    return outfeed_receiver_->AddOutfeedToBuilder(builder, token, consumer_id,
+                                                  arrays);
+  }
+
+  void Callback(Device* device, std::shared_ptr client,
+                uint32_t consumer_id, std::shared_ptr literal) {
+    {
+      absl::MutexLock lock(&mu_);
+      if (outfeed_receiver_shutting_down_) {
+        VLOG(2) << "Ignoring unsafe callback to Python during shutdown";
+        return;
+      }
+    }
+    py::gil_scoped_acquire gil_acquire; // Need GIL also for LiteralToPython
+    py::object literal_python =
+        LiteralToPython(std::move(literal)).ValueOrDie();
+    // callback_python_ should handle all exceptions in user code. If we get
+    // an exception here, it is a bug in the callback and we should stop.
+    callback_python_(WrapWithClient(std::move(client), device),
+                     consumer_id, std::move(literal_python));
+  }
+
+ private:
+  CallbackToPython callback_python_;
+  absl::Mutex mu_;
+  bool outfeed_receiver_shutting_down_ TF_GUARDED_BY(mu_);
+  std::unique_ptr outfeed_receiver_;
+};
+
+} // namespace
+
+void BuildOutfeedReceiverSubmodule(py::module* m) {
+  py::module outfeed_receiver =
+      m->def_submodule("outfeed_receiver", "Outfeed receiver");
+  outfeed_receiver.def(
+      "start",
+      [](OutfeedReceiverForPython::CallbackToPython callback_to_python,
+         std::vector> clients,
+         ssize_t max_callback_queue_size_bytes)
+          -> std::unique_ptr {
+        auto server = absl::make_unique(
+            callback_to_python, clients, max_callback_queue_size_bytes);
+        server->Start();
+        return server;
+      },
+      py::arg("callback_to_python"), py::arg("backends"),
+      py::arg("max_queue_size_bytes") = 256 * 1024 * 1024,
+      R"(Starts a multithreaded outfeed receiver.
+
+      There is one thread for each of the specified devices. When Python
+      drops the last reference to the returned object, the receiver is shut
+      down. The destructor will block until all data is received from
+      devices.
+
+      Args:
+      * callback_to_python: a Python callback to call, with the consumer ID
+        and the data received.
+      * backends: the list of backends to listen on.
+      * max_queue_size_bytes: an optional integer to bound the total size of
+        the arrays in the callback queue. When this limit is reached the
+        device listener pauses.
+      )",
+      py::call_guard());
+
+  py::class_ outfeed_receiver_class(
+      outfeed_receiver, "OutfeedReceiverForPython");
+
+  outfeed_receiver_class.def(
+      "add_outfeed", &OutfeedReceiverForPython::AddOutfeed, py::arg("builder"),
+      py::arg("token"), py::arg("consumer_id"), py::arg("arrays"),
+      R"(Adds an outfeed into the given computation builder.
+
+      Has the side-effect of registering the sent shape along with the consumer
+      ID. Returns an error if the outfeed shape is not compatible with the
+      previously used shape for the same consumer ID.)",
+      py::call_guard());
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/python/outfeed_receiver_py.h b/tensorflow/compiler/xla/python/outfeed_receiver_py.h new file mode 100644 index 00000000000..6b1a712327a --- /dev/null +++ b/tensorflow/compiler/xla/python/outfeed_receiver_py.h @@ -0,0 +1,27 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_OUTFEED_RECEIVER_PY_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_OUTFEED_RECEIVER_PY_H_ + +#include "pybind11/pybind11.h" + +namespace xla { + +void BuildOutfeedReceiverSubmodule(pybind11::module* m); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_OUTFEED_RECEIVER_PY_H_ diff --git a/tensorflow/compiler/xla/python/outfeed_receiver_test.cc b/tensorflow/compiler/xla/python/outfeed_receiver_test.cc new file mode 100644 index 00000000000..ea84b4e18d6 --- /dev/null +++ b/tensorflow/compiler/xla/python/outfeed_receiver_test.cc @@ -0,0 +1,258 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/python/outfeed_receiver.h" + +#include + +#include "absl/synchronization/mutex.h" +#include "tensorflow/compiler/xla/client/executable_build_options.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/pjrt/cpu_device.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/xla/test.h" + +namespace xla { + +namespace { + +Status CompileAndExecute(XlaBuilder* builder, XlaOp root, int device_id, + PjRtClient* client) { + XlaComputation computation = builder->Build(root).ValueOrDie(); + + CompileOptions compile_options; + compile_options.executable_build_options.set_num_replicas(1); + compile_options.executable_build_options.set_num_partitions(1); + DeviceAssignment device_assignment(1, 1); + device_assignment(0, 0) = device_id; + compile_options.executable_build_options.set_device_assignment( + device_assignment); + + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, + PjRtExecutable::Compile(computation, client, std::move(compile_options))); + ExecuteOptions execute_options; + TF_ASSIGN_OR_RETURN(std::vector> output_buffers, + executable->Execute({}, execute_options)); + return Status::OK(); +} + +// Accumulates the received data. 
+class Accumulator { + public: + struct Data { + uint32_t consumer_id; + std::shared_ptr data; + }; + + void Receive(uint32_t consumer_id, std::shared_ptr data) { + absl::MutexLock lock(&mutex_); + received_.push_back(Data{consumer_id, data}); + } + + std::vector received() { + absl::MutexLock lock(&mutex_); + return received_; + } + + private: + absl::Mutex mutex_; + std::vector received_ TF_GUARDED_BY(mutex_); +}; + +TEST(OutfeedReceiverTest, ReceiveOutfeedSimple) { + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr cpu_client, + GetCpuClient(true)); + std::vector> clients{cpu_client}; + + auto receiver = absl::make_unique(); + OutfeedReceiver::Callback callback = + [&receiver](Device* device, std::shared_ptr client, + uint32_t consumer_id, std::shared_ptr data) { + receiver->Receive(consumer_id, data); + }; + auto outfeed_receiver = + std::make_shared(callback, clients, 128); + outfeed_receiver->Start(); + + XlaBuilder builder("execute_test_outfeed"); + constexpr int consumer_id0 = 5; + const Shape shape0 = ShapeUtil::MakeShape(U32, {16}); + XlaOp data = Iota(&builder, shape0, 0); + XlaOp send = outfeed_receiver + ->AddOutfeedToBuilder(&builder, CreateToken(&builder), + consumer_id0, {data}) + .ValueOrDie(); + EXPECT_TRUE(CompileAndExecute(&builder, send, 0, cpu_client.get()).ok()); + + // Shutdown the receiver, to force it to wait to deliver the callbacks. + outfeed_receiver = nullptr; + std::vector received = receiver->received(); + EXPECT_EQ(1, received.size()); + EXPECT_EQ(consumer_id0, received[0].consumer_id); + EXPECT_EQ(ShapeUtil::MakeTupleShape({shape0}), received[0].data->shape()); +} + +TEST(OutfeedReceiverTest, ReceiveOutfeedTwoComputations) { + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr cpu_client, + GetCpuClient(true)); + std::vector> clients{cpu_client}; + + auto receiver = absl::make_unique(); + OutfeedReceiver::Callback callback = + [&receiver](Device* device, std::shared_ptr client, + uint32_t consumer_id, std::shared_ptr data) { + receiver->Receive(consumer_id, data); + }; + auto outfeed_receiver = + std::make_shared(callback, clients, 128); + outfeed_receiver->Start(); + + XlaBuilder builder0("execute_test_outfeed_0"); + constexpr int consumer_id0 = 5; + const Shape shape0 = ShapeUtil::MakeShape(U32, {16}); + XlaOp data0 = Iota(&builder0, shape0, 0); + XlaOp send0 = outfeed_receiver + ->AddOutfeedToBuilder(&builder0, CreateToken(&builder0), + consumer_id0, {data0}) + .ValueOrDie(); + EXPECT_TRUE(CompileAndExecute(&builder0, send0, 0, cpu_client.get()).ok()); + + XlaBuilder builder1("execute_test_outfeed_1"); + constexpr int consumer_id1 = 6; + const Shape shape1 = ShapeUtil::MakeShape(U32, {128}); + XlaOp data1 = Iota(&builder1, shape1, 0); + XlaOp send1 = outfeed_receiver + ->AddOutfeedToBuilder(&builder1, CreateToken(&builder1), + consumer_id1, {data1}) + .ValueOrDie(); + EXPECT_TRUE(CompileAndExecute(&builder1, send1, 0, cpu_client.get()).ok()); + + // Shutdown the receiver, to force it to wait to deliver the callbacks. 
+ outfeed_receiver = nullptr; + std::vector received = receiver->received(); + EXPECT_EQ(2, received.size()); + EXPECT_EQ(consumer_id0, received[0].consumer_id); + EXPECT_EQ(ShapeUtil::MakeTupleShape({shape0}), received[0].data->shape()); + EXPECT_EQ(consumer_id1, received[1].consumer_id); + EXPECT_EQ(ShapeUtil::MakeTupleShape({shape1}), received[1].data->shape()); +} + +TEST(OutfeedReceiverTest, ReceiveOutfeedTwoOutfeed) { + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr cpu_client, + GetCpuClient(true)); + std::vector> clients{cpu_client}; + + auto receiver = absl::make_unique(); + OutfeedReceiver::Callback callback = + [&receiver](Device* device, std::shared_ptr client, + uint32_t consumer_id, std::shared_ptr data) { + receiver->Receive(consumer_id, data); + }; + auto outfeed_receiver = + std::make_shared(callback, clients, 128); + outfeed_receiver->Start(); + + XlaBuilder builder("execute_test_outfeed"); + constexpr int consumer_id0 = 5; + const Shape shape0 = ShapeUtil::MakeShape(U32, {16}); + XlaOp data0 = Iota(&builder, shape0, 0); + XlaOp send0 = outfeed_receiver + ->AddOutfeedToBuilder(&builder, CreateToken(&builder), + consumer_id0, {data0}) + .ValueOrDie(); + + constexpr int consumer_id1 = 6; + const Shape shape1 = ShapeUtil::MakeShape(U32, {128}); + XlaOp data1 = Iota(&builder, shape1, 0); + XlaOp send1 = + outfeed_receiver + ->AddOutfeedToBuilder(&builder, send0, consumer_id1, {data1}) + .ValueOrDie(); + EXPECT_TRUE(CompileAndExecute(&builder, send1, 0, cpu_client.get()).ok()); + + // Shutdown the receiver, to force it to wait to deliver the callbacks. + outfeed_receiver = nullptr; + std::vector received = receiver->received(); + EXPECT_EQ(2, received.size()); + EXPECT_EQ(consumer_id0, received[0].consumer_id); + EXPECT_EQ(ShapeUtil::MakeTupleShape({shape0}), received[0].data->shape()); + EXPECT_EQ(consumer_id1, received[1].consumer_id); + EXPECT_EQ(ShapeUtil::MakeTupleShape({shape1}), received[1].data->shape()); +} + +TEST(OutfeedReceiverTest, DifferentShapeForConsumerIdError) { + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr cpu_client, + GetCpuClient(true)); + std::vector> clients{cpu_client}; + + auto receiver = absl::make_unique(); + OutfeedReceiver::Callback callback = + [&receiver](Device* device, std::shared_ptr client, + uint32_t consumer_id, std::shared_ptr data) { + receiver->Receive(consumer_id, data); + }; + auto outfeed_receiver = + std::make_shared(callback, clients, 128); + outfeed_receiver->Start(); + + XlaBuilder builder("execute_test_outfeed"); + constexpr int consumer_id0 = 5; + const Shape shape0 = ShapeUtil::MakeShape(U32, {16}); + XlaOp data0 = Iota(&builder, shape0, 0); + XlaOp send0 = outfeed_receiver + ->AddOutfeedToBuilder(&builder, CreateToken(&builder), + consumer_id0, {data0}) + .ValueOrDie(); + + const Shape shape1 = ShapeUtil::MakeShape(U32, {128}); + XlaOp data1 = Iota(&builder, shape1, 0); + // A different shape for the same consumer ID. 
+ StatusOr send1 = outfeed_receiver->AddOutfeedToBuilder( + &builder, send0, consumer_id0, {data1}); + EXPECT_FALSE(send1.ok()); + EXPECT_THAT(send1.status().ToString(), + testing::HasSubstr("does not match previous shape element_type")); +} + +TEST(OutfeedReceiverTest, InvalidConsumerIdError) { + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr cpu_client, + GetCpuClient(true)); + std::vector> clients{cpu_client}; + + auto receiver = absl::make_unique(); + OutfeedReceiver::Callback callback = + [&receiver](Device* device, std::shared_ptr client, + uint32_t consumer_id, std::shared_ptr data) { + receiver->Receive(consumer_id, data); + }; + auto outfeed_receiver = + std::make_shared(callback, clients, 128); + outfeed_receiver->Start(); + + XlaBuilder builder("execute_test_outfeed"); + const Shape shape0 = ShapeUtil::MakeShape(U32, {16}); + XlaOp data0 = Iota(&builder, shape0, 0); + StatusOr send0 = outfeed_receiver->AddOutfeedToBuilder( + &builder, CreateToken(&builder), 0, {data0}); + + EXPECT_FALSE(send0.ok()); + EXPECT_THAT(send0.status().ToString(), + testing::HasSubstr("Consumer ID cannot be a reserved value")); +} + +} // namespace + +} // namespace xla diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py index ef0caff0ae6..6d4482af43f 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py @@ -20,6 +20,9 @@ from __future__ import print_function from absl import logging +# Import xla_client to load shared C++ extensions (just CompileOptions at the +# time of writing). +from tensorflow.compiler.xla.python import xla_client # pylint: disable=unused-import from tensorflow.compiler.xla.python.tpu_driver.client import tpu_client_extension as _tpu_client diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index f03595bf677..0b6824e83e9 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -24,17 +24,12 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/types/optional.h" #include "absl/types/span.h" +#include "pybind11/attr.h" #include "pybind11/cast.h" #include "pybind11/numpy.h" #include "pybind11/pybind11.h" #include "pybind11/pytypes.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/lib/comparators.h" -#include "tensorflow/compiler/xla/client/lib/math.h" -#include "tensorflow/compiler/xla/client/lib/qr.h" -#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" -#include "tensorflow/compiler/xla/client/lib/sorting.h" -#include "tensorflow/compiler/xla/client/lib/svd.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" @@ -47,6 +42,8 @@ limitations under the License. #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/python/bfloat16.h" #include "tensorflow/compiler/xla/python/dlpack.h" +#include "tensorflow/compiler/xla/python/ops.h" +#include "tensorflow/compiler/xla/python/outfeed_receiver_py.h" #include "tensorflow/compiler/xla/python/python_ref_manager.h" #include "tensorflow/compiler/xla/python/types.h" #include "tensorflow/compiler/xla/service/custom_call_target_registry.h" @@ -62,15 +59,16 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/profiler/rpc/profiler_server.h" +#include "tensorflow/python/profiler/internal/traceme_wrapper.h" #include "tensorflow/stream_executor/platform.h" namespace xla { +namespace { namespace py = pybind11; -namespace { +using ::tensorflow::profiler::TraceMeWrapper; struct Uniquer { absl::Mutex mu; @@ -304,358 +302,6 @@ StatusOr PjRtBufferCudaArrayInterface(const PjRtBuffer& buffer) { return result; } -void BuildOpsSubmodule(py::module* m) { - // ops submodule, containing free functions that add operators to an - // XlaBuilder. - py::module ops = m->def_submodule("ops", "XLA operations"); - - py::enum_( - ops, "TriangularSolveOptions_Transpose") - .value("TRANSPOSE_INVALID", TriangularSolveOptions::TRANSPOSE_INVALID) - .value("NO_TRANSPOSE", TriangularSolveOptions::NO_TRANSPOSE) - .value("TRANSPOSE", TriangularSolveOptions::TRANSPOSE) - .value("ADJOINT", TriangularSolveOptions::ADJOINT); - - ops.def("AfterAll", &AfterAll, py::arg("builder"), py::arg("tokens")); - ops.def( - "AllReduce", - static_cast, - const absl::optional&, const absl::optional&)>( - &AllReduce), - py::arg("operand"), py::arg("computation"), - py::arg("replica_groups") = py::list(), - py::arg("channel_id") = absl::nullopt, - py::arg("shape_with_layout") = absl::nullopt); - ops.def("AllToAll", &AllToAll, py::arg("operand"), py::arg("split_dimension"), - py::arg("concat_dimension"), py::arg("split_count"), - py::arg("replica_groups") = py::list(), - py::arg("layout") = absl::nullopt); - ops.def("CollectivePermute", &CollectivePermute, py::arg("operand"), - py::arg("source_target_pairs")); - ops.def("CreateToken", &CreateToken, py::arg("builder")); - ops.def("CrossReplicaSum", - static_cast)>( - &CrossReplicaSum), - py::arg("operand"), py::arg("replica_groups") = py::list()); - ops.def("BitcastConvertType", &BitcastConvertType, py::arg("operand"), - py::arg("new_element_type")); - ops.def("Broadcast", &Broadcast, py::arg("operand"), py::arg("sizes")); - ops.def("BroadcastInDim", &BroadcastInDim, py::arg("operand"), - py::arg("shape"), py::arg("broadcast_dimensions")); - ops.def("Call", &Call, py::arg("builder"), py::arg("computation"), - py::arg("operands")); - ops.def("Cholesky", &Cholesky, py::arg("a"), py::arg("lower") = true); - ops.def("Clamp", &Clamp, py::arg("min"), py::arg("operand"), py::arg("max")); - ops.def("Collapse", &Collapse, py::arg("operand"), py::arg("dimensions")); - ops.def("ConcatInDim", &ConcatInDim, py::arg("builder"), py::arg("operands"), - py::arg("dimension")); - ops.def("Conditional", - static_cast, - absl::Span)>(&Conditional), - py::arg("branch_index"), py::arg("branch_computations"), - py::arg("branch_operands")); - ops.def("Conditional", - static_cast(&Conditional), - py::arg("predicate"), py::arg("true_operand"), - py::arg("true_computation"), py::arg("false_operand"), - py::arg("false_computation")); - ops.def("Constant", &ConstantLiteral, py::arg("builder"), py::arg("literal")); - ops.def("ConstantLiteral", &ConstantLiteral, py::arg("builder"), - py::arg("literal")); - ops.def("ConvGeneralDilated", &ConvGeneralDilated, py::arg("lhs"), - py::arg("rhs"), py::arg("window_strides"), py::arg("padding"), - py::arg("lhs_dilation"), py::arg("rhs_dilation"), - py::arg("dimension_numbers"), py::arg("feature_group_count") = 1, - py::arg("batch_group_count") = 1, - py::arg("precision_config") = 
nullptr); - ops.def("ConvertElementType", &ConvertElementType, py::arg("operand"), - py::arg("new_element_type")); - ops.def( - "CustomCall", - [](XlaBuilder* builder, const py::bytes& call_target_name, - absl::Span operands, const Shape& shape, - const py::bytes& opaque) -> XlaOp { - return CustomCall(builder, call_target_name, operands, shape, opaque); - }, - py::arg("builder"), py::arg("call_target_name"), py::arg("operands"), - py::arg("shape"), py::arg("opaque") = py::bytes("")); - ops.def( - "CustomCallWithLayout", - [](XlaBuilder* builder, const py::bytes& call_target_name, - absl::Span operands, const Shape& shape_with_layout, - absl::Span operand_shapes_with_layout, - const py::bytes& opaque) -> XlaOp { - return CustomCallWithLayout(builder, call_target_name, operands, - shape_with_layout, - operand_shapes_with_layout, opaque); - }, - py::arg("builder"), py::arg("call_target_name"), py::arg("operands"), - py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"), - py::arg("opaque") = py::bytes("")); - ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"), - py::arg("precision_config") = nullptr); - ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"), - py::arg("dimension_numbers"), py::arg("precision_config") = nullptr); - ops.def("DynamicSlice", - static_cast, - absl::Span)>(&DynamicSlice), - py::arg("operand"), py::arg("start_indices"), py::arg("slice_sizes")); - ops.def("DynamicUpdateSlice", - static_cast)>( - &DynamicUpdateSlice), - py::arg("operand"), py::arg("update"), py::arg("start_indices")); - - ops.def("Fft", &Fft, py::arg("operand"), py::arg("fft_type"), - py::arg("fft_length")); - - ops.def("Gather", &Gather, py::arg("a"), py::arg("start_indices"), - py::arg("dimension_numbers"), py::arg("slice_sizes"), - py::arg("indices_are_sorted") = false); - ops.def("GetTupleElement", &GetTupleElement, py::arg("tuple_data"), - py::arg("index")); - ops.def("InfeedWithToken", &InfeedWithToken, py::arg("token"), - py::arg("shape"), py::arg("config") = ""); - ops.def("Iota", - static_cast(&Iota), - py::arg("builder"), py::arg("shape"), py::arg("iota_dimension")); - ops.def("Iota", - static_cast(&Iota), - py::arg("builder"), py::arg("type"), py::arg("size")); - ops.def("Map", &Map, py::arg("builder"), py::arg("operands"), - py::arg("computation"), py::arg("dimensions"), - py::arg("static_operands") = py::list()); - ops.def("NextAfter", &NextAfter, py::arg("from"), py::arg("to")); - ops.def("OutfeedWithToken", &OutfeedWithToken, py::arg("operand"), - py::arg("token"), py::arg("shape_with_layout"), - py::arg("outfeed_config") = ""); - ops.def("Pad", &Pad, py::arg("operand"), py::arg("padding_value"), - py::arg("padding_config")); - ops.def("Parameter", - static_cast&)>( - &Parameter), - py::arg("builder"), py::arg("parameter_number"), py::arg("shape"), - py::arg("name") = "", - py::arg("replicated_at_leaf_buffers") = std::vector()); - ops.def( - "QR", - [](XlaOp a, bool full_matrices) -> StatusOr> { - TF_ASSIGN_OR_RETURN(auto qr, QRDecomposition(a, full_matrices)); - return std::make_pair(qr.q, qr.r); - }, - py::arg("operand"), py::arg("full_matrices")); - ops.def( - "Eigh", - [](XlaOp a, bool lower, int64 max_iter, - float epsilon) -> std::pair { - auto eigh = SelfAdjointEig(a, lower, max_iter, epsilon); - return std::make_pair(eigh.v, eigh.w); - }, - py::arg("a"), py::arg("lower") = true, py::arg("max_iter") = 100, - py::arg("epsilon") = 1e-6); - ops.def( - "SVD", - [](XlaOp a, int64 max_iter, - float epsilon) -> std::tuple { - auto svd = SVD(a, max_iter, 
epsilon); - return std::make_tuple(svd.u, svd.d, svd.v); - }, - py::arg("a"), py::arg("max_iter") = 100, py::arg("epsilon") = 1e-6); - ops.def("Reduce", - static_cast, - absl::Span, const XlaComputation&, - absl::Span)>(&Reduce), - py::arg("builder"), py::arg("operands"), py::arg("init_values"), - py::arg("computation"), py::arg("dimensions_to_reduce")); - ops.def("ReducePrecision", &ReducePrecision, py::arg("operand"), - py::arg("exponent_bits"), py::arg("mantissa_bits")); - ops.def("ReduceWindowWithGeneralPadding", &ReduceWindowWithGeneralPadding, - py::arg("operand"), py::arg("init_value"), py::arg("computation"), - py::arg("window_dimensions"), py::arg("window_strides"), - py::arg("base_dilations"), py::arg("window_dilations"), - py::arg("padding")); - ops.def("ReplicaId", &ReplicaId, py::arg("builder")); - ops.def("Reshape", - static_cast, - absl::Span)>(&Reshape), - py::arg("operand"), py::arg("dimensions"), py::arg("new_sizes")); - ops.def("Reshape", - static_cast)>(&Reshape), - py::arg("operand"), py::arg("new_sizes")); - ops.def("Rev", &Rev, py::arg("operand"), py::arg("dimensions")); - ops.def("RngNormal", &RngNormal, py::arg("mu"), py::arg("sigma"), - py::arg("shape")); - ops.def("RngUniform", &RngUniform, py::arg("a"), py::arg("b"), - py::arg("shape")); - ops.def("Scatter", &Scatter, py::arg("input"), py::arg("scatter_indices"), - py::arg("updates"), py::arg("update_computation"), - py::arg("dimension_numbers"), py::arg("indices_are_sorted") = false, - py::arg("unique_indices") = false); - ops.def("Select", &Select, py::arg("pred"), py::arg("on_true"), - py::arg("on_false")); - ops.def("SelectAndScatterWithGeneralPadding", - &SelectAndScatterWithGeneralPadding, py::arg("operand"), - py::arg("select"), py::arg("window_dimensions"), - py::arg("window_strides"), py::arg("padding"), py::arg("source"), - py::arg("init_value"), py::arg("scatter")); - ops.def("Slice", &Slice, py::arg("operand"), py::arg("start_indices"), - py::arg("limit_indices"), py::arg("strides")); - ops.def("SliceInDim", &SliceInDim, py::arg("operand"), py::arg("start_index"), - py::arg("limit_index"), py::arg("stride"), py::arg("dimno")); - ops.def( - "Sort", - [](XlaBuilder* builder, absl::Span operands, - absl::optional comparator, int64 dimension, - bool is_stable) -> XlaOp { - return builder->ReportErrorOrReturn([&]() -> StatusOr { - std::vector operand_types; - for (const auto& operand : operands) { - TF_ASSIGN_OR_RETURN(auto operand_shape, builder->GetShape(operand)); - operand_types.push_back(operand_shape.element_type()); - } - - if (comparator) { - return Sort(operands, **comparator, dimension, is_stable); - } else { - return Sort(operands, - CreateScalarLtComputation(operand_types, builder), - dimension, is_stable); - } - }); - }, - py::arg("builder"), py::arg("operands"), - py::arg("comparator") = absl::nullopt, py::arg("dimension") = -1, - py::arg("is_stable") = false); - ops.def("TopK", &TopK, py::arg("input"), py::arg("k")); - ops.def("Transpose", &Transpose, py::arg("operand"), py::arg("permutation")); - ops.def("TriangularSolve", &TriangularSolve, py::arg("a"), py::arg("b"), - py::arg("left_side"), py::arg("lower"), py::arg("unit_diagonal"), - py::arg("transpose_a")); - ops.def("Tuple", &Tuple, py::arg("builder"), py::arg("elements")); - ops.def("While", &While, py::arg("condition"), py::arg("body"), - py::arg("init")); - - ops.def("Igamma", &Igamma, py::arg("a"), py::arg("x")); - ops.def("Igammac", &Igammac, py::arg("a"), py::arg("x")); - ops.def("IgammaGradA", &IgammaGradA, py::arg("a"), 
py::arg("x")); - ops.def("RandomGammaGrad", &RandomGammaGrad, py::arg("a"), py::arg("x")); - ops.def("RegularizedIncompleteBeta", &RegularizedIncompleteBeta, py::arg("a"), - py::arg("b"), py::arg("x")); - -#define BINARY_OP(op) \ - ops.def( \ - #op, \ - [](XlaOp a, XlaOp b, absl::optional> dims) { \ - return dims ? op(a, b, *dims) : op(a, b); \ - }, \ - py::arg("lhs"), py::arg("rhs"), \ - py::arg("broadcast_dimensions") = absl::nullopt) - BINARY_OP(Eq); - BINARY_OP(Ne); - BINARY_OP(Ge); - BINARY_OP(Gt); - BINARY_OP(Lt); - BINARY_OP(Le); - BINARY_OP(Add); - BINARY_OP(Sub); - BINARY_OP(Mul); - BINARY_OP(Div); - BINARY_OP(Rem); - BINARY_OP(Max); - BINARY_OP(Min); - BINARY_OP(And); - BINARY_OP(Or); - BINARY_OP(Xor); - BINARY_OP(ShiftLeft); - BINARY_OP(ShiftRightArithmetic); - BINARY_OP(ShiftRightLogical); - BINARY_OP(Atan2); - BINARY_OP(Pow); - BINARY_OP(Complex); -#undef BINARY_OP - -#define UNARY_OP(op) ops.def(#op, &op) - UNARY_OP(Not); - UNARY_OP(PopulationCount); - UNARY_OP(Clz); - UNARY_OP(Abs); - UNARY_OP(Exp); - UNARY_OP(Expm1); - UNARY_OP(Floor); - UNARY_OP(Ceil); - UNARY_OP(Round); - UNARY_OP(Log); - UNARY_OP(Log1p); - UNARY_OP(Sign); - UNARY_OP(Cos); - UNARY_OP(Sin); - UNARY_OP(Tanh); - UNARY_OP(IsFinite); - UNARY_OP(Neg); - UNARY_OP(Sqrt); - UNARY_OP(Rsqrt); - UNARY_OP(Square); - UNARY_OP(Reciprocal); - UNARY_OP(Erfc); - UNARY_OP(Erf); - UNARY_OP(ErfInv); - UNARY_OP(Lgamma); - UNARY_OP(Digamma); - UNARY_OP(BesselI0e); - UNARY_OP(BesselI1e); - UNARY_OP(Acos); - UNARY_OP(Asin); - UNARY_OP(Atan); - UNARY_OP(Tan); - UNARY_OP(Acosh); - UNARY_OP(Asinh); - UNARY_OP(Atanh); - UNARY_OP(Cosh); - UNARY_OP(Sinh); - UNARY_OP(Real); - UNARY_OP(Imag); - UNARY_OP(Conj); -#undef UNARY_OP -} - -// Helper to implement TraceMe as a context manager in Python. -class TraceMeContextManager { - public: - explicit TraceMeContextManager(py::str name, py::kwargs kwargs) - : name_(std::move(name)), kwargs_(std::move(kwargs)) {} - - void Enter() { - if (IsEnabled()) { - std::string name(name_); - if (!kwargs_.empty()) { - absl::StrAppend(&name, "#"); - bool first = true; - for (const auto entry : kwargs_) { - absl::StrAppend(&name, first ? 
"" : ",", - std::string(py::str(entry.first)), "=", - std::string(py::str(entry.second))); - first = false; - } - absl::StrAppend(&name, "#"); - } - traceme_.emplace(std::move(name)); - } - } - py::object Exit(const py::object& ex_type, const py::object& ex_value, - const py::object& traceback) { - traceme_.reset(); - return py::none(); - } - - static bool IsEnabled() { return tensorflow::profiler::TraceMe::Active(); } - - private: - py::str name_; - py::kwargs kwargs_; - absl::optional traceme_; -}; void BuildProfilerSubmodule(py::module* m) { py::module profiler = @@ -672,11 +318,19 @@ void BuildProfilerSubmodule(py::module* m) { }, py::arg("port")); - py::class_ traceme_class(profiler, "TraceMe"); + py::class_ traceme_class(profiler, "TraceMe", + py::module_local()); traceme_class.def(py::init()) - .def("__enter__", &TraceMeContextManager::Enter) - .def("__exit__", &TraceMeContextManager::Exit) - .def_static("is_enabled", &TraceMeContextManager::IsEnabled); + .def("__enter__", [](py::object self) -> py::object { return self; }) + .def("__exit__", + [](py::object self, const py::object& ex_type, + const py::object& ex_value, + const py::object& traceback) -> py::object { + py::cast(self)->Stop(); + return py::none(); + }) + .def("set_metadata", &TraceMeWrapper::SetMetadata) + .def_static("is_enabled", &TraceMeWrapper::IsEnabled); } } // namespace @@ -872,11 +526,7 @@ PYBIND11_MODULE(xla_extension, m) { DebugOptions* debug_options = options.executable_build_options.mutable_debug_options(); // Sets fast-math-disabling default options expected by JAX. - // TODO(phawkins): make these XLA-wide defaults. - debug_options->set_xla_cpu_fast_math_honor_infs(true); - debug_options->set_xla_cpu_fast_math_honor_nans(true); - debug_options->set_xla_cpu_fast_math_honor_division(true); - debug_options->set_xla_cpu_fast_math_honor_functions(true); + debug_options->set_xla_cpu_enable_fast_min_max(false); debug_options->set_xla_gpu_enable_fast_min_max(false); return options; })) @@ -934,34 +584,6 @@ PYBIND11_MODULE(xla_extension, m) { "client", [](const ClientAndPtr& device) { return device.client; }) .def("__str__", &Device::DebugString) - // TODO(phawkins): remove capitalized names after updating callers. - .def("TransferToInfeed", - [](const Device& device, const LiteralSlice& literal) { - GlobalPyRefManager()->CollectGarbage(); - py::gil_scoped_release gil_release; - TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, - device.GetLocalDeviceState()); - return local_device->client()->TransferToInfeedLocal( - literal, local_device->device_ordinal()); - }) - .def( - "TransferFromOutfeed", - [](const Device& device, const Shape& shape) -> StatusOr { - GlobalPyRefManager()->CollectGarbage(); - std::shared_ptr literal_shared; - { - py::gil_scoped_release gil_release; - TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, - device.GetLocalDeviceState()); - TF_ASSIGN_OR_RETURN( - Literal literal, - local_device->client()->TransferFromOutfeedLocal( - shape, local_device->device_ordinal())); - - literal_shared = std::make_shared(std::move(literal)); - } - return LiteralToPython(std::move(literal_shared)); - }) .def("transfer_to_infeed", [](const Device& device, const LiteralSlice& literal) { GlobalPyRefManager()->CollectGarbage(); @@ -1248,28 +870,6 @@ PYBIND11_MODULE(xla_extension, m) { .def("size_of_generated_code_in_bytes", &PjRtExecutable::SizeOfGeneratedCodeInBytes) .def("delete", &PjRtExecutable::Delete) - // TODO(phawkins): delete capitalized methods after updating callers. 
- .def("Delete", &PjRtExecutable::Delete) - .def( - "Execute", - [](const PjRtExecutable& executable, - absl::Span args) - -> StatusOr>> { - py::gil_scoped_release gil_release; - ExecuteOptions options; - options.untuple_result = true; - TF_ASSIGN_OR_RETURN( - std::vector> output_buffers, - executable.Execute(args, options)); - std::vector> outputs; - outputs.reserve(output_buffers.size()); - for (auto& buffer : output_buffers) { - outputs.push_back(WrapWithClient( - executable.client()->shared_from_this(), std::move(buffer))); - } - return outputs; - }, - py::arg("arguments")) .def( "execute", [](const PjRtExecutable& executable, @@ -1290,33 +890,6 @@ PYBIND11_MODULE(xla_extension, m) { return outputs; }, py::arg("arguments")) - // TODO(phawkins): delete capitalized methods after updating callers. - .def( - "ExecuteOnLocalDevices", - [](const PjRtExecutable& executable, - absl::Span> args) - -> StatusOr< - std::vector>>> { - py::gil_scoped_release gil_release; - ExecuteOptions options; - options.untuple_result = true; - TF_ASSIGN_OR_RETURN( - std::vector>> - output_buffers, - executable.ExecuteOnLocalDevices(args, options)); - std::vector>> outputs; - outputs.resize(output_buffers.size()); - for (int computation = 0; computation < output_buffers.size(); - ++computation) { - for (auto& buffer : output_buffers[computation]) { - outputs[computation].push_back( - WrapWithClient(executable.client()->shared_from_this(), - std::move(buffer))); - } - } - return outputs; - }, - py::arg("arguments")) .def( "execute_on_local_devices", [](const PjRtExecutable& executable, @@ -1377,7 +950,19 @@ PYBIND11_MODULE(xla_extension, m) { &DebugOptions::set_xla_cpu_fast_math_honor_functions) .def_property("xla_gpu_enable_fast_min_max", &DebugOptions::xla_gpu_enable_fast_min_max, - &DebugOptions::set_xla_gpu_enable_fast_min_max); + &DebugOptions::set_xla_gpu_enable_fast_min_max) + .def_property("xla_backend_optimization_level", + &DebugOptions::xla_backend_optimization_level, + &DebugOptions::set_xla_backend_optimization_level) + .def_property("xla_cpu_enable_xprof_traceme", + &DebugOptions::xla_cpu_enable_xprof_traceme, + &DebugOptions::set_xla_cpu_enable_xprof_traceme) + .def_property("xla_llvm_disable_expensive_passes", + &DebugOptions::xla_llvm_disable_expensive_passes, + &DebugOptions::set_xla_llvm_disable_expensive_passes) + .def_property("xla_test_all_input_layouts", + &DebugOptions::xla_test_all_input_layouts, + &DebugOptions::set_xla_test_all_input_layouts); py::class_(m, "ExecutableBuildOptions") .def(py::init<>()) @@ -1406,7 +991,10 @@ PYBIND11_MODULE(xla_extension, m) { options.device_assignment()) : absl::nullopt; }, - &ExecutableBuildOptions::set_device_assignment); + &ExecutableBuildOptions::set_device_assignment) + .def_property("use_spmd_partitioning", + &ExecutableBuildOptions::use_spmd_partitioning, + &ExecutableBuildOptions::set_use_spmd_partitioning); py::class_(m, "XlaComputation") .def(py::init([](const py::bytes& serialized_hlo_module_proto) @@ -1415,12 +1003,6 @@ PYBIND11_MODULE(xla_extension, m) { proto.ParseFromString(serialized_hlo_module_proto); return absl::make_unique(proto); })) - // TODO(phawkins): delete capitalized names after updating callers. 
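Earlier in this file's diff, the profiler hunk replaces the TraceMeContextManager helper with a direct binding of TraceMeWrapper: __enter__ now simply returns the object, __exit__ stops the trace, and set_metadata can attach extra key/value pairs while the scope is open. The pure-Python sketch below models that context-manager protocol and the "name#key=value,...#" encoding the old Enter() helper built; ToyTraceMe is a hypothetical name and nothing here calls the real xla_extension profiler bindings.

class ToyTraceMe(object):
  """Hypothetical pure-Python model of the TraceMe context-manager protocol."""

  def __init__(self, name, **kwargs):
    # The C++ helper encodes keyword metadata as "name#k1=v1,k2=v2#"; in the real
    # binding the trace starts as soon as the object is constructed.
    if kwargs:
      name += "#" + ",".join("%s=%s" % (k, v) for k, v in kwargs.items()) + "#"
    self._name = name
    self._active = True

  def __enter__(self):
    return self  # __enter__ just returns self, as in the new binding

  def set_metadata(self, **kwargs):
    # Mirrors set_metadata: append more key/value pairs while the scope is open.
    if self._active and kwargs:
      self._name += "#" + ",".join("%s=%s" % (k, v) for k, v in kwargs.items()) + "#"

  def __exit__(self, ex_type, ex_value, traceback):
    self._active = False  # the real binding calls TraceMeWrapper::Stop() here
    return None

with ToyTraceMe("train_step", step=3) as tm:
  tm.set_metadata(phase="forward")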
- .def("GetProgramShape", &XlaComputation::GetProgramShape) - .def("GetSerializedProto", &GetComputationSerializedProto) - .def("GetHloText", &GetComputationHloText) - .def("GetHloDotGraph", &GetComputationHloDotGraph) - .def("Hash", &HashComputation) .def("get_hlo_module", &GetHloModule) .def("program_shape", &XlaComputation::GetProgramShape) .def("as_serialized_hlo_module_proto", &GetComputationSerializedProto) @@ -1513,28 +1095,7 @@ PYBIND11_MODULE(xla_extension, m) { }, "Builds a computation from the contents of the builder.", py::arg("root") = absl::nullopt) - .def("ClearOpMetadata", &XlaBuilder::ClearOpMetadata) .def("GetShape", &XlaBuilder::GetShape) - .def( - "GetProgramShape", - [](const XlaBuilder& builder, - absl::optional root) -> StatusOr { - return root ? builder.GetProgramShape(*root) - : builder.GetProgramShape(); - }, - py::arg("root") = absl::nullopt) - .def("IsConstant", &XlaBuilder::IsConstant) - .def("SetOpMetadata", &XlaBuilder::SetOpMetadata) - .def("SetSharding", &XlaBuilder::SetSharding) - .def("ClearSharding", &XlaBuilder::ClearSharding) - .def("SetUpAlias", - [](XlaBuilder& builder, const std::vector& output_index, - int64 param_number, const std::vector& param_index) { - builder.SetUpAlias( - ShapeIndex(output_index.begin(), output_index.end()), - param_number, - ShapeIndex(param_index.begin(), param_index.end())); - }) .def( "build", [](XlaBuilder& builder, absl::optional root) { @@ -1565,17 +1126,7 @@ PYBIND11_MODULE(xla_extension, m) { ShapeIndex(param_index.begin(), param_index.end())); }); - // TODO(phawkins): delete capitalized names after updating callers - m.def("BufferToDLPackManagedTensor", BufferToDLPackManagedTensor); m.def("buffer_to_dlpack_managed_tensor", BufferToDLPackManagedTensor); - m.def("DLPackManagedTensorToBuffer", - [](const py::capsule& tensor, std::shared_ptr client) - -> StatusOr> { - TF_ASSIGN_OR_RETURN( - std::unique_ptr buffer, - DLPackManagedTensorToBuffer(tensor, client.get())); - return WrapWithClient(std::move(client), std::move(buffer)); - }); m.def("dlpack_managed_tensor_to_buffer", [](const py::capsule& tensor, std::shared_ptr client) -> StatusOr> { @@ -1615,6 +1166,7 @@ PYBIND11_MODULE(xla_extension, m) { BuildOpsSubmodule(&m); BuildProfilerSubmodule(&m); + BuildOutfeedReceiverSubmodule(&m); py::class_> diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index d9cd906939d..76c3bc33a91 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -300,13 +300,13 @@ CompileOptions = _xla.CompileOptions # An Executable is a C++ class that duck types with the following API: # class Executable(object): # def local_devices(self) -> [Device]: -# def Execute(self, arguments : [Buffer]) -> Buffer: +# def execute(self, arguments : [Buffer]) -> Buffer: # """Execute on one replica with Buffer arguments and return value.""" # -# def SizeOfGeneratedCodeInBytes(self) -> int: +# def size_of_generated_code_in_bytes(self) -> int: # """Return generated binary size, or -1 if not known.""" # -# def ExecuteOnLocalDevices(self, arguments: [[Buffer]]) -> [Buffer]: +# def execute_on_local_devices(self, arguments: [[Buffer]]) -> [Buffer]: # """Execute on many replicas with Buffer arguments and return value. 
# # Args: @@ -329,7 +329,7 @@ def execute_with_python_values(executable, arguments, backend): return backend.buffer_from_pyval(arg, device=executable.local_devices()[0]) arguments = [put(arg) for arg in arguments] - outputs = executable.Execute(arguments) + outputs = executable.execute(arguments) return [x.to_py() for x in outputs] @@ -359,7 +359,7 @@ def execute_with_python_values_replicated(executable, arguments, backend): flat_arg_buffers = flat_arg_buffers[len(replica_args):] return [[x.to_py() for x in xs] - for xs in executable.ExecuteOnLocalDevices(arg_buffers)] + for xs in executable.execute_on_local_devices(arg_buffers)] class PaddingType(enum.Enum): diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 62b3fae018a..000db2cb16b 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -115,6 +115,10 @@ def TestFactory(xla_backend, cloud_tpu=False): """Convenience wrapper to create Numpy arrays with a np.float32 dtype.""" return np.array(*args, dtype=np.float32, **kwargs) + def NumpyArrayF64(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.float64 dtype.""" + return np.array(*args, dtype=np.float64, **kwargs) + def NumpyArrayS32(*args, **kwargs): """Convenience wrapper to create Numpy arrays with a np.int32 dtype.""" return np.array(*args, dtype=np.int32, **kwargs) @@ -882,12 +886,20 @@ def TestFactory(xla_backend, cloud_tpu=False): ops.Abs(ops.Constant(c, arr)) self._ExecuteAndCompareClose(c, expected=[np.abs(arr)]) - def testTanh(self): + def testTanhF32(self): c = self._NewComputation() - arr = NumpyArrayF32([3.3, 12.1]) + arr = NumpyArrayF32([-0.2, 3.3, 12.1, 0.1, 0.0001]) ops.Tanh(ops.Constant(c, arr)) self._ExecuteAndCompareClose(c, expected=[np.tanh(arr)]) + def testTanhF64(self): + if self.backend.platform == "tpu": + self.skipTest("TPU doesn't support 64bit tanh") + c = self._NewComputation() + arr = NumpyArrayF64([-0.2, 3.3, 12.1, 0.1, 0.0001]) + ops.Tanh(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.tanh(arr)], rtol=1e-12) + def testTranspose(self): def _TransposeAndTest(array, permutation): @@ -2029,8 +2041,11 @@ def TestFactory(xla_backend, cloud_tpu=False): return tests -def InstantiateTests(globals_dict, backend, test_prefix="", **kw): - for klass in TestFactory(backend, **kw): +def InstantiateTests(globals_dict, backend_fn, test_prefix="", **kw): + # Avoid creating a new backend per test (this causes GPU OOM, and is probably + # inefficient). + backend_fn = functools.lru_cache(maxsize=None)(backend_fn) + for klass in TestFactory(backend_fn, **kw): test = type(test_prefix + klass.__name__, (klass,), {}) # Clean up the qualified names of the tests to not include the test factory. 
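The InstantiateTests change above wraps the backend factory in functools.lru_cache so every generated test class reuses a single backend instance instead of constructing one per test (which was causing GPU OOM). A self-contained sketch of the same memoization pattern; make_backend is a hypothetical stand-in for the real backend_fn:

import functools

def make_backend():
  # Pretend this is the expensive per-platform backend constructor.
  print("creating backend")
  return object()

# With maxsize=None the decorated factory acts as a simple memo: the first call
# builds the backend, every later call returns the cached instance.
make_backend = functools.lru_cache(maxsize=None)(make_backend)

backend_a = make_backend()
backend_b = make_backend()
assert backend_a is backend_b  # "creating backend" was printed exactly once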
test.__qualname__ = test.__name__ diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 799d5654840..dfc9aae94e0 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -460,6 +460,97 @@ cc_library( ], ) +cc_library( + name = "hlo_sharding_util", + srcs = [ + "hlo_sharding_util.cc", + ], + hdrs = [ + "hlo_sharding_util.h", + ], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:array", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/types:optional", + ], +) + +tf_cc_test( + name = "hlo_sharding_util_test", + srcs = [ + "hlo_sharding_util_test.cc", + ], + deps = [ + ":hlo_sharding_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + +cc_library( + name = "sharding_propagation", + srcs = [ + "sharding_propagation.cc", + ], + hdrs = [ + "sharding_propagation.h", + ], + deps = [ + ":dot_as_convolution_util", + ":hlo", + ":hlo_graph_dumper", + ":hlo_pass", + ":hlo_sharding_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) + +tf_cc_test( + name = "sharding_propagation_test", + srcs = [ + "sharding_propagation_test.cc", + ], + deps = [ + "hlo_matchers", + ":hlo_parser", + ":sharding_propagation", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + +cc_library( + name = "dot_as_convolution_util", + srcs = [ + "dot_as_convolution_util.cc", + ], + hdrs = [ + "dot_as_convolution_util.h", + ], + deps = [ + ":hlo", + ":shape_inference", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "@com_google_absl//absl/types:optional", + ], +) + tf_cc_test( name = "dynamic_parameter_binding_test", srcs = ["dynamic_parameter_binding_test.cc"], @@ -2098,6 +2189,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], ) @@ -2397,6 +2490,42 @@ tf_cc_test( ], ) +cc_library( + name = "all_gather_decomposer", + srcs = ["all_gather_decomposer.cc"], + hdrs = ["all_gather_decomposer.h"], + deps = [ + ":hlo", + ":hlo_casting_utils", + ":hlo_pass", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "all_gather_decomposer_test", + srcs = ["all_gather_decomposer_test.cc"], + deps = [ + ":all_gather_decomposer", + ":hlo", + ":hlo_matchers", + ":hlo_parser", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/tests:hlo_test_base", 
+ "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "tuple_simplifier", srcs = ["tuple_simplifier.cc"], @@ -3217,6 +3346,7 @@ cc_library( ":heap_simulator", ":hlo_cost_analysis", "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/core/lib/math:math_util", ], ) @@ -3234,6 +3364,29 @@ tf_cc_test( ], ) +cc_library( + name = "memory_space_propagation", + srcs = ["memory_space_propagation.cc"], + hdrs = ["memory_space_propagation.h"], + deps = [ + ":hlo", + ":hlo_dataflow_analysis", + ":hlo_pass", + ], +) + +tf_cc_test( + name = "memory_space_propagation_test", + srcs = ["memory_space_propagation_test.cc"], + deps = [ + ":hlo_parser", + ":memory_space_propagation", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_dce", srcs = ["hlo_dce.cc"], @@ -3787,6 +3940,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:core", "@llvm-project//llvm:transform_utils", ], @@ -4406,6 +4560,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:regexp_internal", + "@com_google_absl//absl/base", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 55af8726dc8..4025cb46f18 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -472,8 +472,9 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { HloInstruction* dot); HloComputation* GetOrCreateScalarAddComputation(PrimitiveType type) { - if (scalar_add_computation_) { - return scalar_add_computation_; + HloComputation*& scalar_add_computation = scalar_add_computations_[type]; + if (scalar_add_computation) { + return scalar_add_computation; } HloComputation::Builder b("scalar_add_computation"); @@ -485,9 +486,9 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { HloInstruction::CreateParameter(1, shape, "scalar_rhs")); auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs)); - scalar_add_computation_ = + scalar_add_computation = computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); - return scalar_add_computation_; + return scalar_add_computation; } // Tries to fold a kPad in the input or filter into the convolution @@ -508,6 +509,13 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { // Tries to convert slice(reshape(X)) into reshape(slice(X)) StatusOr TryToReorderSliceAndReshape(HloInstruction* slice); + // Tries to simplify `(and (< a N) (< a K))` in cases where `N <= K` into + // `(< a N)`. This is crucial for being able to figure out the loop trip + // count. + // + // Assumes that the input is conjunction. + StatusOr TrySimplifyTautologicalCompare(HloInstruction* conjunction); + // Useful when we want to use the same visitor over multiple computations. 
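TrySimplifyTautologicalCompare, declared above and implemented in the next hunk, folds a conjunction of two less-than comparisons of the same variable against scalar constants into a single comparison against the smaller bound, which lets later passes recover loop trip counts. A tiny Python model of the rewrite over (variable, bound) pairs rather than HLO instructions; the function name is hypothetical:

def simplify_tautological_compare(lhs, rhs):
  """Models (a < N) and (a < K) -> (a < min(N, K)) when both sides test the same a.

  lhs and rhs are (var_name, bound) pairs standing in for LT compares against
  scalar constants; returns the fused compare, or None if the rewrite does not apply.
  """
  var_l, bound_l = lhs
  var_r, bound_r = rhs
  if var_l != var_r:
    return None  # different variables: the conjunction is not tautological
  return (var_l, min(bound_l, bound_r))

# (param < 10) and (param < 100) collapses to (param < 10), matching the
# CompareSimplified test added later in this change.
assert simplify_tautological_compare(("param", 10), ("param", 100)) == ("param", 10)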
void ResetState(HloComputation* computation); @@ -521,8 +529,8 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { // Whether algebraic simplification has occurred. bool changed_ = false; - // Cached computation for adding two scalar F32. - HloComputation* scalar_add_computation_ = nullptr; + // Cached computation for adding two scalars of a given type. + absl::flat_hash_map scalar_add_computations_; AlgebraicSimplifier* simplifier_ = nullptr; }; @@ -856,6 +864,50 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { return Status::OK(); } +StatusOr AlgebraicSimplifierVisitor::TrySimplifyTautologicalCompare( + HloInstruction* conjunction) { + HloInstruction *lhs, *rhs; + if (!Match(conjunction, m::And(m::Op(&lhs), m::Op(&rhs)))) { + return false; + } + struct LessThanCompareInfo { // (LT var constant) + HloInstruction* var; + int64 constant; + }; + + auto get_compare_info = + [&](HloInstruction* cmp) -> absl::optional { + HloInstruction *lhs, *rhs; + auto scalar_shape_matcher = + m::Shape().IsEffectiveScalar().WithElementType(PrimitiveType::S32); + if (Match(cmp, m::Compare(m::Op(&lhs), + m::Constant(&rhs).WithShape(scalar_shape_matcher)) + .WithComparisonDirection(ComparisonDirection::kLt))) { + return {LessThanCompareInfo{lhs, *rhs->literal().GetFirstInteger()}}; + } else if (Match( + cmp, + m::Compare(m::Constant(&lhs).WithShape(scalar_shape_matcher), + m::Op(&rhs)) + .WithComparisonDirection(ComparisonDirection::kGt))) { + return {LessThanCompareInfo{rhs, *lhs->literal().GetFirstInteger()}}; + } + return absl::nullopt; + }; + + absl::optional lhs_info = get_compare_info(lhs); + absl::optional rhs_info = get_compare_info(rhs); + if (lhs_info && rhs_info && lhs_info->var == rhs_info->var) { + int64 new_bound = std::min(lhs_info->constant, rhs_info->constant); + TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( + conjunction, + HloInstruction::CreateCompare(lhs->shape(), lhs_info->var, + MakeScalarLike(lhs_info->var, new_bound), + ComparisonDirection::kLt))); + return true; + } + return false; +} + Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) { HloInstruction *lhs, *rhs; CHECK(Match(logical_and, m::And(m::Op(&lhs), m::Op(&rhs)))); @@ -890,6 +942,13 @@ Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) { return Status::OK(); } + // Simplify tautological conjunctions. 
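The new divide rewrite above turns convert(pred) / broadcast(y) into broadcast(1/y) * convert(pred). Because the converted operand only holds 0.0 and 1.0, multiplying by the precomputed reciprocal produces exactly the same values while replacing a per-element divide with a multiply. A small numpy check of that equivalence with made-up inputs (illustration only, not XLA code):

import numpy as np

pred = np.array([True, False, True, True])
x = pred.astype(np.float32)             # convert(pred): every element is 0.0 or 1.0
y = np.float32(4.0)                     # the scalar that was broadcast as the divisor

original = x / y                        # convert(pred) / broadcast(y)
rewritten = (np.float32(1.0) / y) * x   # broadcast(1/y) * convert(pred)

# Multiplying 0.0 or 1.0 by the reciprocal is exact, so the two forms agree bit-for-bit.
assert np.array_equal(original, rewritten)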
+ TF_ASSIGN_OR_RETURN(bool found_tautological_compare, + TrySimplifyTautologicalCompare(logical_and)); + if (found_tautological_compare) { + return Status::OK(); + } + return Status::OK(); } @@ -1423,6 +1482,22 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { return ReplaceInstruction(divide, new_divide); } + // If X is a convert from pred, then + // X / broadcast(Y) => broadcast(1/Y) * X + if (Match(divide, + m::Divide( + m::Convert(&a, + m::Op().WithShape(m::Shape().WithElementType(PRED))), + m::Broadcast(m::Op(&b).WithShape(m::Shape().IsScalar()))))) { + TF_ASSIGN_OR_RETURN( + auto recip, MakeBinaryHlo(HloOpcode::kDivide, MakeScalarLike(b, 1), b)); + auto recip_bcast = computation_->AddInstruction( + HloInstruction::CreateBroadcast(divide->shape(), recip, {})); + TF_ASSIGN_OR_RETURN(auto mul, + MakeBinaryHlo(HloOpcode::kMultiply, recip_bcast, a)); + return ReplaceInstruction(divide, mul); + } + return Status::OK(); } @@ -2983,6 +3058,20 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( return false; } HloInstruction* operand = broadcast->mutable_operand(0); + auto is_scalar_broadcast = [](const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kBroadcast && + ShapeUtil::IsScalar(instruction->operand(0)->shape()); + }; + auto is_equal_broadcast = [operand, + broadcast](const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kBroadcast && + ShapeUtil::Equal(operand->shape(), + instruction->operand(0)->shape()) && + broadcast->dimensions() == instruction->dimensions(); + }; + auto is_compatible_broadcast = [&](const HloInstruction* instruction) { + return is_scalar_broadcast(instruction) || is_equal_broadcast(instruction); + }; for (HloInstruction* user : broadcast->users()) { if (user->user_count() == 0 && user != computation_->root_instruction()) { continue; @@ -3001,18 +3090,20 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( continue; } - // Find the unique non-scalar operand or continue if there isn't one. - int64 scalar_broadcast_count = 0; + // Check if all the operands of the user are compatible broadcasts for + // sinking. 
(They are either scalar broadcasts or broadcasts casting + // from/to the same shape/dimensions) + int64 compatible_broadcast_count = 0; int64 broadcast_use_count = 0; for (HloInstruction* user_operand : user->operands()) { - if (user_operand->opcode() == HloOpcode::kBroadcast && - ShapeUtil::IsScalar(user_operand->operand(0)->shape())) { - ++scalar_broadcast_count; + if (is_compatible_broadcast(user_operand)) { + ++compatible_broadcast_count; } else if (broadcast == user_operand) { ++broadcast_use_count; } } - if (scalar_broadcast_count + broadcast_use_count != user->operand_count()) { + if (compatible_broadcast_count + broadcast_use_count != + user->operand_count()) { continue; } std::vector new_operands; @@ -3020,14 +3111,24 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( Shape changed_shape; for (HloInstruction* user_operand : user->operands()) { - if (user_operand->opcode() == HloOpcode::kBroadcast && - ShapeUtil::IsScalar(user_operand->operand(0)->shape())) { - changed_shape = ShapeUtil::ChangeElementType( - operand->shape(), user_operand->shape().element_type()); - simplifier_->UpdateLayout(&changed_shape); - new_operands.push_back( - computation_->AddInstruction(HloInstruction::CreateBroadcast( - changed_shape, user_operand->mutable_operand(0), {}))); + // If this is a broadcast operand that is not our original broadcast input + // to this function then we might need to change the input. + if (is_compatible_broadcast(user_operand)) { + // If this is a broadcast from a scalar value rewrite a broadcast from + // the scalar to the new shape enforced from the other broadcast + // operands. + if (is_scalar_broadcast(user_operand)) { + changed_shape = ShapeUtil::ChangeElementType( + operand->shape(), user_operand->shape().element_type()); + simplifier_->UpdateLayout(&changed_shape); + new_operands.push_back( + computation_->AddInstruction(HloInstruction::CreateBroadcast( + changed_shape, user_operand->mutable_operand(0), {}))); + } else { + // For the non-scalar broadcasts we guarantee that the shape of the + // operand of the broadcast needs to be already a compatible shape. 
+ new_operands.push_back(user_operand->mutable_operand(0)); + } } else { CHECK_EQ(broadcast, user_operand); new_operands.push_back(operand); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 6c8e80aa963..bcfc2fdc740 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -338,6 +338,79 @@ TEST_F(AlgebraicSimplifierTest, MultiplyReassociateMergeBroadcastedConstants) { m::ConstantScalar(3.0)))))); } +TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsScalar) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + b0 = f32[4] broadcast(p0), dimensions={} + b1 = f32[4] broadcast(p1), dimensions={} + ROOT multiply = f32[4] multiply(b1, b0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Broadcast(m::Multiply(m::Broadcast(m::Parameter(1)), + m::Broadcast(m::Parameter(0)))))); +} + +TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsConstantMix) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[4] parameter(0) + c0 = f32[] constant(2.0) + b0 = f32[4,2] broadcast(c0), dimensions={} + b1 = f32[4,2] broadcast(p0), dimensions={0} + ROOT multiply = f32[4,2] multiply(b1, b0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Broadcast(m::Multiply( + m::Parameter(0), m::Broadcast(m::ConstantScalar(2.0)))))); +} + +TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsNonScalar) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[4] parameter(0) + p1 = f32[4] parameter(1) + b0 = f32[4,2] broadcast(p0), dimensions={0} + b1 = f32[4,2] broadcast(p1), dimensions={0} + ROOT multiply = f32[4,2] multiply(b1, b0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Broadcast(m::Multiply(m::Parameter(1), m::Parameter(0))))); +} + +TEST_F(AlgebraicSimplifierTest, ElementwiseNoSinkBroadcastsDifferentDims) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[4] parameter(0) + p1 = f32[8] parameter(1) + b0 = f32[4,8] broadcast(p0), dimensions={0} + b1 = f32[4,8] broadcast(p1), dimensions={1} + ROOT multiply = f32[4,8] multiply(b1, b0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Multiply(m::Broadcast(m::Parameter(1)), + m::Broadcast(m::Parameter(0))))); +} + TEST_F(AlgebraicSimplifierTest, MultiplyReassociateMultiplyOfConstantAndBroadcast) { const char* kModuleStr = R"( @@ -5761,6 +5834,44 @@ TEST_F(AlgebraicSimplifierTest, CompareSame) { GmockMatch(m::Broadcast(m::ConstantScalar(true)))); } +TEST_F(AlgebraicSimplifierTest, CompareSimplified) { + const char* kModuleStr = R"( + HloModule m + test { + param = s32[] parameter(0) + c1 = s32[] 
constant(10) + c2 = s32[] constant(100) + cmp1 = pred[] compare(param, c1), direction=LT + cmp2 = pred[] compare(param, c2), direction=LT + ROOT out = pred[] and(cmp1, cmp2) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Compare(m::Op(), m::Op().IsConstantScalar(10)) + .WithComparisonDirection(ComparisonDirection::kLt))); +} + +TEST_F(AlgebraicSimplifierTest, CompareSimplifiedReversed) { + const char* kModuleStr = R"( + HloModule m + test { + param = s32[] parameter(0) + c1 = s32[] constant(10) + c2 = s32[] constant(100) + cmp1 = pred[] compare(param, c1), direction=LT + cmp2 = pred[] compare(c2, param), direction=GT + ROOT out = pred[] and(cmp1, cmp2) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Compare(m::Op(), m::Op().IsConstantScalar(10)) + .WithComparisonDirection(ComparisonDirection::kLt))); +} + TEST_F(AlgebraicSimplifierTest, CanDisableDotToMultiplyRewrite) { // Some backends may have better performance by treating an outer product as a // Dot, rather than a broadcast Multiply @@ -6462,5 +6573,43 @@ TEST_F(AlgebraicSimplifierTest, SwapConvOperands) { EXPECT_EQ(conv->window().dimensions(1).padding_high(), 1); } +TEST_F(AlgebraicSimplifierTest, ScalarDividePredicate) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = pred[2] parameter(0) + cvt = f32[2] convert(p0) + p1 = f32[] parameter(1) + bcast = f32[2] broadcast(p1), dimensions={} + ROOT div = f32[2] divide(cvt, bcast) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::Convert(m::Parameter(0)), + m::Broadcast(m::Divide(m::ConstantScalar(1), m::Parameter(1)))))); +} + +TEST_F(AlgebraicSimplifierTest, MultipleDotStrengthReductions) { + constexpr char kModuleStr[] = R"( + HloModule test + ENTRY test { + a = c64[2,2] parameter(0) + b = c64[2] parameter(1) + cd = c64[2] dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0} + c = f64[2,2] parameter(2) + d = f64[2] parameter(3) + dd = f64[2] dot(c, d), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT tuple = (c64[2], f64[2]) tuple(cd, dd) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_EQ(3, m->computation_count()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/all_gather_decomposer.cc b/tensorflow/compiler/xla/service/all_gather_decomposer.cc new file mode 100644 index 00000000000..00b9adaea43 --- /dev/null +++ b/tensorflow/compiler/xla/service/all_gather_decomposer.cc @@ -0,0 +1,157 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/all_gather_decomposer.h" + +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +// Creates a computation of x + y. +HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module) { + HloComputation::Builder sum_b("add"); + auto x = sum_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(type, {}), "x")); + auto y = sum_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(type, {}), "y")); + if (type == PRED) { + sum_b.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(type, {}), HloOpcode::kOr, x, y)); + } else { + sum_b.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(type, {}), HloOpcode::kAdd, x, y)); + } + HloComputation* reduction = module->AddEmbeddedComputation(sum_b.Build()); + return reduction; +} + +Status DecomposeAllGather(HloAllGatherInstruction* ag, HloComputation* comp) { + const int64 shard_size = + ag->operand(0)->shape().dimensions(ag->all_gather_dimension()); + const int64 ag_size = ag->shape().dimensions(ag->all_gather_dimension()); + TF_RET_CHECK(ag_size % shard_size == 0); + int64 partition_count = ag_size / shard_size; + auto zero = comp->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(ag->shape().element_type()))); + zero = comp->AddInstruction( + HloInstruction::CreateBroadcast(ag->shape(), zero, {})); + auto zero_index = comp->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(U32))); + std::vector start_indices(ag->shape().rank(), zero_index); + auto shard_id_from_subgroup = [&](HloInstruction* replica_or_global_id) { + if (ag->replica_groups().empty()) { + return replica_or_global_id; + } + if (ag->replica_groups().size() == 1) { + // Whether the group is {1, 2, ..., N - 1}. + bool trivial_group = true; + for (int64 i = 0; i < ag->replica_groups()[0].replica_ids_size(); ++i) { + if (ag->replica_groups()[0].replica_ids(i) != i) { + trivial_group = false; + break; + } + } + if (trivial_group) { + CHECK_EQ(partition_count, ag->replica_groups()[0].replica_ids_size()); + return replica_or_global_id; + } + } + // Create a table of shard IDs for each replica_or_global_id, then slice it + // using replica_or_global_id. 
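DecomposeAllGather above (the function body continues below) lowers an all-gather into three steps: broadcast a zero buffer with the full output shape, dynamic-update-slice the local shard in at offset shard_id * shard_size along the gather dimension, and all-reduce the results with a binary add (or OR for PRED) so every participant ends up with the concatenated output. A numpy model of that data flow for the simple cross-replica case, with made-up shard contents; this illustrates the arithmetic only and is not XLA code:

import numpy as np

def decomposed_all_gather(shards, gather_dim=0):
  """Zero-pad each shard into the full shape at its offset, then all-reduce with add."""
  shard_size = shards[0].shape[gather_dim]
  full_shape = list(shards[0].shape)
  full_shape[gather_dim] = shard_size * len(shards)
  padded = []
  for shard_id, shard in enumerate(shards):
    buf = np.zeros(full_shape, dtype=shard.dtype)         # broadcast(0)
    index = [slice(None)] * len(full_shape)
    index[gather_dim] = slice(shard_id * shard_size,       # shard_id * shard_size
                              shard_id * shard_size + shard_size)
    buf[tuple(index)] = shard                              # dynamic-update-slice
    padded.append(buf)
  return np.sum(padded, axis=0)                            # all-reduce (add)

shards = [np.full((2, 3), replica, dtype=np.float32) for replica in range(4)]
assert np.array_equal(decomposed_all_gather(shards, gather_dim=0),
                      np.concatenate(shards, axis=0))      # what all-gather produces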
+ std::vector shard_ids(ag->replica_groups().size() * + ag->replica_groups()[0].replica_ids_size()); + for (const auto& group : ag->replica_groups()) { + for (int64 i = 0; i < group.replica_ids_size(); ++i) { + shard_ids[group.replica_ids(i)] = i; + } + } + auto id_table = comp->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1(shard_ids))); + auto shard_id = comp->AddInstruction(HloInstruction::CreateDynamicSlice( + ShapeUtil::MakeShape(U32, {1}), id_table, {replica_or_global_id}, {1})); + shard_id = comp->AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(U32, {}), shard_id)); + return shard_id; + }; + HloInstruction* shard_id; + if (ag->channel_id().has_value()) { + if (ag->use_global_device_ids()) { + auto pid = comp->AddInstruction(HloInstruction::CreatePartitionId()); + auto rid = comp->AddInstruction(HloInstruction::CreateReplicaId()); + auto pcount = comp->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(partition_count))); + auto global_id = comp->AddInstruction(HloInstruction::CreateBinary( + pid->shape(), HloOpcode::kAdd, pid, + comp->AddInstruction(HloInstruction::CreateBinary( + pid->shape(), HloOpcode::kMultiply, rid, pcount)))); + shard_id = shard_id_from_subgroup(global_id); + } else { + TF_RET_CHECK(!ag->replica_groups().empty()); + TF_RET_CHECK(ag->replica_groups()[0].replica_ids_size() == 1); + shard_id = comp->AddInstruction(HloInstruction::CreatePartitionId()); + } + } else { + shard_id = shard_id_from_subgroup( + comp->AddInstruction(HloInstruction::CreateReplicaId())); + } + start_indices[ag->all_gather_dimension()] = + comp->AddInstruction(HloInstruction::CreateBinary( + shard_id->shape(), HloOpcode::kMultiply, shard_id, + comp->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(shard_size))))); + auto dus = comp->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + zero->shape(), zero, ag->mutable_operand(0), start_indices)); + auto ar = comp->AddInstruction(HloInstruction::CreateAllReduce( + dus->shape(), {dus}, + MakeBinaryAdd(dus->shape().element_type(), comp->parent()), + ag->replica_groups(), + /*constrain_layout=*/ag->constrain_layout(), ag->channel_id(), + ag->use_global_device_ids())); + TF_RETURN_IF_ERROR(ag->ReplaceAllUsesWith(ar)); + TF_RETURN_IF_ERROR(comp->RemoveInstructionAndUnusedOperands(ag)); + return Status::OK(); +} + +StatusOr AllGatherDecomposer::Run(HloModule* module) { + bool changed = false; + for (auto comp : module->MakeNonfusionComputations()) { + for (auto hlo : comp->MakeInstructionPostOrder()) { + if (hlo->opcode() != HloOpcode::kAllGather) { + continue; + } + auto ag = Cast(hlo); + if (should_decompose_(*ag)) { + TF_RETURN_IF_ERROR(DecomposeAllGather(ag, comp)); + changed = true; + } + } + } + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/all_gather_decomposer.h b/tensorflow/compiler/xla/service/all_gather_decomposer.h new file mode 100644 index 00000000000..6b20765c709 --- /dev/null +++ b/tensorflow/compiler/xla/service/all_gather_decomposer.h @@ -0,0 +1,48 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ALL_GATHER_DECOMPOSER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_ALL_GATHER_DECOMPOSER_H_ + +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// AllGatherDecomposer is a pass which converts unsupported all-gathers into +// dynamic-update-slices and all-reduces. +class AllGatherDecomposer : public HloModulePass { + public: + explicit AllGatherDecomposer( + std::function should_decompose) + : should_decompose_(std::move(should_decompose)) {} + AllGatherDecomposer() + : should_decompose_( + [](const HloAllGatherInstruction& ag) { return true; }) {} + absl::string_view name() const override { return "all_gather_decomposer"; } + + // Run AllGatherDecomposer pass on computations in 'module'. + // Returns whether the 'module' was changed. + StatusOr Run(HloModule* module) override; + + private: + std::function should_decompose_; + int64 partition_count_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_ALL_GATHER_DECOMPOSER_H_ diff --git a/tensorflow/compiler/xla/service/all_gather_decomposer_test.cc b/tensorflow/compiler/xla/service/all_gather_decomposer_test.cc new file mode 100644 index 00000000000..3df5e51a7c2 --- /dev/null +++ b/tensorflow/compiler/xla/service/all_gather_decomposer_test.cc @@ -0,0 +1,160 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/all_gather_decomposer.h" + +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +using ::testing::AllOf; +namespace op = xla::testing::opcode_matchers; +using AllGatherDecomposerTest = HloTestBase; + +TEST_F(AllGatherDecomposerTest, CrossReplicaAllGather) { + const string module_str = R"( +HloModule module + +ENTRY entry { + param0 = f32[10,20] parameter(0) + ROOT ag = f32[10,80] all-gather(param0), replica_groups={}, dimensions={1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((module_str))); + AllGatherDecomposer decomposer; + TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(op::Constant()), op::Parameter(0), op::Constant(), + op::Multiply(op::ReplicaId(), op::Constant())))); +} + +TEST_F(AllGatherDecomposerTest, CrossPartitionAllGather) { + const string module_str = R"( +HloModule module + +ENTRY entry { + param0 = f32[10,20] parameter(0) + ROOT ag = f32[10,80] all-gather(param0), replica_groups={{0}}, channel_id=1, + dimensions={1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((module_str))); + AllGatherDecomposer decomposer; + TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(op::Constant()), op::Parameter(0), op::Constant(), + op::Multiply(op::PartitionId(), op::Constant())))); +} + +TEST_F(AllGatherDecomposerTest, CrossReplicaAllGatherWithTrivialGroup) { + const string module_str = R"( +HloModule module + +ENTRY entry { + param0 = f32[10,20] parameter(0) + ROOT ag = f32[10,80] all-gather(param0), replica_groups={{0,1,2,3}}, + dimensions={1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((module_str))); + AllGatherDecomposer decomposer; + TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(op::Constant()), op::Parameter(0), op::Constant(), + op::Multiply(op::ReplicaId(), op::Constant())))); +} + +TEST_F(AllGatherDecomposerTest, CrossReplicaAllGatherWithSubgroups) { + const string module_str = R"( +HloModule module + +ENTRY entry { + param0 = f32[10,20] parameter(0) + ROOT ag = f32[10,80] all-gather(param0), + replica_groups={{2,1,0,3}, {4,6,7,5}}, dimensions={1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((module_str))); + AllGatherDecomposer decomposer; + 
TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); + EXPECT_TRUE(changed); + auto id = + AllOf(op::Shape("u32[]"), + op::Reshape(op::DynamicSlice(op::Constant(), op::ReplicaId()))); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(op::Constant()), op::Parameter(0), + op::Constant(), op::Multiply(id, op::Constant())))); +} + +TEST_F(AllGatherDecomposerTest, CrossReplicaAllGatherWithSubgroupsGlobalIds) { + const string module_str = R"( +HloModule module + +ENTRY entry { + param0 = f32[10,20] parameter(0) + ROOT ag = f32[10,80] all-gather(param0), + replica_groups={{2,1,0,3}, {4,6,7,5}}, dimensions={1}, channel_id=1, + use_global_device_ids=true +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((module_str))); + AllGatherDecomposer decomposer; + TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); + EXPECT_TRUE(changed); + auto global_id = + op::Add(op::PartitionId(), op::Multiply(op::ReplicaId(), op::Constant())); + auto id = AllOf(op::Shape("u32[]"), + op::Reshape(op::DynamicSlice(op::Constant(), global_id))); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(op::Constant()), op::Parameter(0), + op::Constant(), op::Multiply(id, op::Constant())))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_support.cc b/tensorflow/compiler/xla/service/bfloat16_support.cc index abb695fa486..30d764225c2 100644 --- a/tensorflow/compiler/xla/service/bfloat16_support.cc +++ b/tensorflow/compiler/xla/service/bfloat16_support.cc @@ -79,6 +79,7 @@ bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision( const HloInstruction& hlo, int64 operand_index) { switch (hlo.opcode()) { case HloOpcode::kAbs: + case HloOpcode::kAllGather: case HloOpcode::kAllToAll: case HloOpcode::kBroadcast: case HloOpcode::kClamp: diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 67cdb081a91..6cd58b86f0c 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -261,7 +261,7 @@ void BufferAllocation::AddAssignment(const HloValue& buffer, int64 offset, Shape* shape = ShapeUtil::GetMutableSubshape( position.instruction->mutable_shape(), position.index); if (shape->has_layout()) { - shape->mutable_layout()->set_memory_space(buffer.color().value()); + shape->mutable_layout()->set_memory_space(buffer.color()); } } } @@ -272,7 +272,7 @@ BufferAllocationProto BufferAllocation::ToProto() const { proto.set_size(size_); proto.set_is_thread_local(is_thread_local_); proto.set_is_tuple(is_tuple_); - proto.set_color(color_.value()); + proto.set_color(color_); if (is_entry_computation_parameter_) { proto.set_is_entry_computation_parameter(true); for (int64 idx : param_shape_index()) { @@ -336,8 +336,8 @@ static const HloInstruction* GetOutputInstruction( string BufferAllocation::ToString() const { string output; StrAppendFormat(&output, "allocation %d: %p, size %d", index_, this, size()); - if (color().value() != 0) { - StrAppend(&output, ", color ", color().value()); + if (color() != 0) { + StrAppend(&output, ", color ", color()); } if (is_entry_computation_parameter()) { const HloInstruction* param = GetEntryParameterInstruction(*this); @@ -607,9 +607,7 @@ void BufferAssignment::AddAssignment(BufferAllocation* allocation, // 
BufferAllocation. void BufferAssignment::CombineTempAllocations() { VLOG(1) << "CombineTempAllocations()"; - flat_hash_map - combined_allocation_map; + flat_hash_map combined_allocation_map; // Move all temp allocations into a single run at the end of the allocations // vector. @@ -1059,8 +1057,8 @@ Status BufferAssigner::MergeInplaceOpBuffers(BufferAssignment* assignment) { // The instruction or operand color is excluded because it was assigned by // memory_space_assignment. - if (excluded_colors.contains(instruction_buffer.color().value()) || - excluded_colors.contains(operand_buffer.color().value())) { + if (excluded_colors.contains(instruction_buffer.color()) || + excluded_colors.contains(operand_buffer.color())) { continue; } @@ -1353,13 +1351,10 @@ Status BufferAssigner::AssignBuffersForComputations( return Status::OK(); } -flat_hash_map, - LogicalBuffer::Color::Hasher> +flat_hash_map> BufferAssigner::SplitBuffersByColor( const flat_hash_set& buffers) { - flat_hash_map, - LogicalBuffer::Color::Hasher> - color_map; + flat_hash_map> color_map; for (auto buffer : buffers) { color_map[buffer->color()].insert(buffer); } @@ -1374,8 +1369,7 @@ Status BufferAssigner::AssignPresetBuffers( } // Create an allocation for each preset color. - absl::flat_hash_map + absl::flat_hash_map preset_allocations; for (auto& color_and_info : preset_assignments_->assignment_informations()) { LogicalBuffer::Color color(color_and_info.first); diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 2a02d3776ce..50a4750601b 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -673,8 +673,7 @@ class BufferAssigner { // Split a set of buffers into several sets, each of which contains buffers // colored with the same color. absl::flat_hash_map, - LogicalBuffer::Color::Hasher> + absl::flat_hash_set> SplitBuffersByColor(const absl::flat_hash_set& buffers); // If true, allocate buffers for constant instructions. diff --git a/tensorflow/compiler/xla/service/buffer_value.cc b/tensorflow/compiler/xla/service/buffer_value.cc index b1abba20689..58e8086f5e9 100644 --- a/tensorflow/compiler/xla/service/buffer_value.cc +++ b/tensorflow/compiler/xla/service/buffer_value.cc @@ -59,7 +59,7 @@ LogicalBufferProto BufferValue::ToProto(const SizeFunction& size_fn) const { ToLocationProto(*instruction(), index()); proto.mutable_defined_at()->Swap(&proto_location); if (has_color()) { - proto.set_color(color().value()); + proto.set_color(color()); } return proto; } diff --git a/tensorflow/compiler/xla/service/buffer_value.h b/tensorflow/compiler/xla/service/buffer_value.h index 44cd7b5ebbd..bd2a09e4aaf 100644 --- a/tensorflow/compiler/xla/service/buffer_value.h +++ b/tensorflow/compiler/xla/service/buffer_value.h @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/int_type.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -86,7 +85,7 @@ namespace xla { class BufferValue { public: - TF_LIB_GTL_DEFINE_INT_TYPE(Color, int64); + using Color = int64; // Id is a unique identifier for the BufferValue to facilitate efficient // collections of BufferValues with stable iteration order. 
@@ -154,7 +153,7 @@ class BufferValue { static LogicalBufferProto::Location ToLocationProto( const HloInstruction& instruction, const ShapeIndex& index); - const Color kInvalidColor = Color(-1); + const Color kInvalidColor = -1; protected: BufferValue(HloInstruction* instruction, const ShapeIndex& index, Id id); diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index 8c76e912011..ce9c8a4ea62 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -91,6 +91,7 @@ CompileOnlyService::CompileAheadOfTime( TF_RETURN_IF_ERROR(options.static_device_assignment().Serialize( execution_options.mutable_device_assignment())); } + execution_options.set_use_spmd_partitioning(options.use_spmd_partitioning()); for (const AotXlaComputationInstance& instance : computations) { TF_RET_CHECK(instance.computation.has_host_program_shape()); *execution_options.mutable_shape_with_output_layout() = diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index cf646159a38..57b24e372e6 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -76,6 +76,7 @@ class AotCompilationOptions { virtual int64 replica_count() const { return 0; } virtual int64 num_cores() const { return 0; } + virtual bool use_spmd_partitioning() const { return false; } // Optional allocator that may be used for allocating temp space on the device // during compilation. diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.cc b/tensorflow/compiler/xla/service/conditional_simplifier.cc index f60742a8c23..a606e44c5ef 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier.cc @@ -16,11 +16,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/conditional_simplifier.h" #include +#include #include #include #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/call_graph.h" @@ -163,94 +166,92 @@ StatusOr TryRemoveConditional(HloInstruction* conditional) { return true; } StatusOr TryRemoveUnusedConditionalOperands( - HloInstruction* conditional, - std::map>* changed_computations) { - // Avoid dealing with sharding. - if (conditional->has_sharding()) { + HloComputation* computation, + const absl::flat_hash_set& calling_conditionals) { + HloInstruction* param = computation->parameter_instruction(0); + // Do not remove from the root instruction. + if (param == computation->root_instruction()) { return false; } - std::vector> tuple_indices_to_keep( - conditional->branch_count()); - bool will_change = false; - for (int64 i = 0; i < conditional->branch_count(); ++i) { - HloComputation* computation = conditional->branch_computation(i); - if (changed_computations->count(computation) > 0) { - will_change = true; - break; - } - HloInstruction* param = computation->parameter_instruction(0); - // Do not remove the root instruction. - if (param == computation->root_instruction()) { - return false; - } - // There is nothing to be removed for non-tuple operands. 
- if (!param->shape().IsTuple()) { - return false; - } - for (HloInstruction* user : param->users()) { - // If the user is not a get tuple element, assume it is unsafe to remove - // elements from the tuple. - if (user->opcode() != HloOpcode::kGetTupleElement) { - return false; - } - tuple_indices_to_keep[i].insert(user->tuple_index()); - } - // If not all tuple elements are used in this conditional branch, some can - // removed from the computation. - if (tuple_indices_to_keep[i].size() != - ShapeUtil::TupleElementCount(param->shape())) { - will_change = true; - } + // There is nothing to be removed for non-tuple operands. + if (!param->shape().IsTuple()) { + return false; } - - if (!will_change) { + std::set tuple_indices_to_keep; + for (HloInstruction* user : param->users()) { + // If the user is not a get tuple element, assume it is unsafe to remove + // elements from the tuple. + if (user->opcode() != HloOpcode::kGetTupleElement) { + return false; + } + tuple_indices_to_keep.insert(user->tuple_index()); + } + // If all tuple elements are used in this conditional branch, there is nothing + // to be removed. + int64 old_tuple_element_count = ShapeUtil::TupleElementCount(param->shape()); + if (tuple_indices_to_keep.size() == old_tuple_element_count) { return false; } - for (int64 branch = 0; branch < conditional->branch_count(); ++branch) { - const Shape& old_shape = conditional->operand(branch + 1)->shape(); - int64 old_tuple_element_count = ShapeUtil::TupleElementCount(old_shape); - // Clone the computation in case it is called by another instruction. - HloComputation* computation = conditional->branch_computation(branch); - if (changed_computations - ->insert({computation, tuple_indices_to_keep[branch]}) - .second) { - HloInstruction* param = computation->parameter_instruction(0); + // Create a new tuple shape based on the indices actually used by this + // computation branch. + std::vector new_tuple_shapes; + new_tuple_shapes.reserve(tuple_indices_to_keep.size()); + std::vector map(old_tuple_element_count, -1); + for (int64 i : tuple_indices_to_keep) { + map[i] = new_tuple_shapes.size(); + new_tuple_shapes.push_back(param->shape().tuple_shapes(i)); + } + Shape tuple_shape = ShapeUtil::MakeTupleShape(new_tuple_shapes); + // Clone the computation in case it is called by another non-conditional + // instruction. + HloComputation* new_computation = + computation->parent()->AddEmbeddedComputation(computation->Clone()); + param = new_computation->parameter_instruction(0); + // Reset the parameter shape of the computation. + *param->mutable_shape() = tuple_shape; - // Create a new tuple shape based on the indices actually used by this - // branch. - std::vector new_tuple_shapes; - new_tuple_shapes.reserve(tuple_indices_to_keep[branch].size()); - std::vector map(old_tuple_element_count, -1); - for (int64 i : tuple_indices_to_keep[branch]) { - map[i] = new_tuple_shapes.size(); - new_tuple_shapes.push_back(old_shape.tuple_shapes(i)); - } - Shape tuple_shape = ShapeUtil::MakeTupleShape(new_tuple_shapes); - // Reset the parameter shape of the computation. - *param->mutable_shape() = tuple_shape; + // Reroute the GTE instructions to new tuple indices. + for (HloInstruction* user : param->users()) { + user->set_tuple_index(map[user->tuple_index()]); + } - // Reroute the GTE instructions to new tuple indices. - for (HloInstruction* user : param->users()) { - user->set_tuple_index(map[user->tuple_index()]); - } + // Adjust the operand shape of all calling conditionals. 
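The rewritten TryRemoveUnusedConditionalOperands above keeps only the tuple indices that some get-tuple-element user actually reads, builds a smaller parameter shape, and records an old-to-new index map so the GTEs and every calling conditional's operand tuple can be rerouted (as the following hunk does). A compact Python model of that pruning and remapping step, using plain lists in place of HLO shapes; the names are hypothetical:

def prune_tuple_operand(element_shapes, used_indices):
  """Drops unused tuple elements and remaps get-tuple-element indices.

  element_shapes: per-element shape strings standing in for the old tuple shape.
  used_indices:   set of tuple indices read by get-tuple-element users.
  Returns (new_shapes, old_to_new), where old_to_new[i] == -1 for dropped elements.
  """
  old_to_new = [-1] * len(element_shapes)
  new_shapes = []
  for i in sorted(used_indices):
    old_to_new[i] = len(new_shapes)  # same trick as map[i] = new_tuple_shapes.size()
    new_shapes.append(element_shapes[i])
  return new_shapes, old_to_new

# Four operands but only indices 0 and 2 are read, as in the simplifier test below:
# the parameter tuple shrinks to two elements and GTE users are retargeted.
shapes = ["f32[20,40]", "f32[40,40]", "f32[20,40]", "f32[40,40]"]
new_shapes, old_to_new = prune_tuple_operand(shapes, used_indices={0, 2})
assert new_shapes == ["f32[20,40]", "f32[20,40]"]
assert old_to_new == [0, -1, 1, -1]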
+ for (HloInstruction* conditional : calling_conditionals) { + // Avoid dealing with sharding. + if (conditional->has_sharding()) { + continue; } + for (int64 branch = 0; branch < conditional->branch_count(); ++branch) { + if (conditional->branch_computation(branch) != computation) { + continue; + } + conditional->set_branch_computation(branch, new_computation); + const Shape& old_shape = conditional->operand(branch + 1)->shape(); - // Reroute the operand tuple through a tuple of gte instructions of the - // original operand tuple. - const auto& to_keep = (*changed_computations)[computation]; - std::vector new_tuple_operands; - new_tuple_operands.reserve(to_keep.size()); - for (int64 i : to_keep) { - new_tuple_operands.push_back(conditional->parent()->AddInstruction( - HloInstruction::CreateGetTupleElement( - old_shape.tuple_shapes(i), - conditional->mutable_operand(branch + 1), i))); + // Reroute the operand tuple through a tuple of gte instructions of the + // original operand tuple. + std::vector new_tuple_operands; + new_tuple_operands.reserve(tuple_indices_to_keep.size()); + for (int64 i : tuple_indices_to_keep) { + new_tuple_operands.push_back(conditional->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement( + old_shape.tuple_shapes(i), + conditional->mutable_operand(branch + 1), i))); + } + HloInstruction* new_tuple = conditional->parent()->AddInstruction( + HloInstruction::CreateTuple(new_tuple_operands)); + TF_RETURN_IF_ERROR( + conditional->ReplaceOperandWithDifferentShape(branch + 1, new_tuple)); + CHECK(ShapeUtil::Compatible(conditional->operand(branch + 1)->shape(), + conditional->branch_computation(branch) + ->parameter_instruction(0) + ->shape())); + CHECK(ShapeUtil::Compatible( + conditional->shape(), + conditional->branch_computation(branch)->root_instruction()->shape())) + << conditional->branch_computation(branch)->ToString(); } - HloInstruction* new_tuple = conditional->parent()->AddInstruction( - HloInstruction::CreateTuple(new_tuple_operands)); - TF_RETURN_IF_ERROR( - conditional->ReplaceOperandWithDifferentShape(branch + 1, new_tuple)); } return true; } @@ -333,7 +334,7 @@ bool RemoveUnusedTupleElements(HloInstruction* conditional_op) { } // Compute old-to-new (old-to-new) indices mapping. 
-  std::map<int64, int64> new_to_old_mapping, old_to_new_mapping;
+  absl::flat_hash_map<int64, int64> new_to_old_mapping, old_to_new_mapping;
   auto old_iter = used_indices.begin();
   for (int new_index = 0; new_index < new_tuple_shapes_size; ++new_index) {
     old_iter = std::find(old_iter, used_indices.end(), true);
@@ -519,7 +520,8 @@ bool MergeDuplicateTupleElements(HloInstruction* conditional) {
   };

   bool changed = false;
-  std::map<std::vector<const HloInstruction*>, int64> index_collision_table;
+  absl::flat_hash_map<std::vector<const HloInstruction*>, int64>
+      index_collision_table;
   for (int i = 0; i < conditional->shape().tuple_shapes_size(); ++i) {
     const std::vector<const HloInstruction*> ith_operands_vector =
         vectorize_branches_root_tuple_ith_operand(i);
@@ -551,16 +553,34 @@ StatusOr<bool> ConditionalSimplifier::Run(HloModule* module) {
     }
   }

-  std::map<HloComputation*, std::set<int64>> changed_computations;
+  absl::flat_hash_set<HloInstruction*> removed_conditionals;
   for (HloInstruction* conditional_op : conditional_ops) {
     changed |= MergeDuplicateTupleElements(conditional_op);
     changed |= RemoveUnusedTupleElements(conditional_op);
     changed |= ReplaceRootWithEmptyTupleIfNoUsers(conditional_op);
     TF_ASSIGN_OR_RETURN(bool result, TryRemoveConditional(conditional_op));
-    if (!result) {
-      TF_ASSIGN_OR_RETURN(result, TryRemoveUnusedConditionalOperands(
-                                       conditional_op, &changed_computations));
+    if (result) {
+      removed_conditionals.insert(conditional_op);
+      changed = true;
     }
+  }
+  // Try to remove unused conditional operands from branch computations. We need
+  // to be careful to adjust *all* calling conditional ops if we do that, so
+  // lets collect them first.
+  absl::flat_hash_map<HloComputation*, absl::flat_hash_set<HloInstruction*>>
+      calling_conditionals;
+  for (HloInstruction* conditional : conditional_ops) {
+    if (removed_conditionals.contains(conditional)) {
+      continue;
+    }
+    for (int64 branch = 0; branch < conditional->branch_count(); ++branch) {
+      calling_conditionals[conditional->branch_computation(branch)].insert(
+          conditional);
+    }
+  }
+  for (const auto& entry : calling_conditionals) {
+    TF_ASSIGN_OR_RETURN(bool result, TryRemoveUnusedConditionalOperands(
+                                         entry.first, entry.second));
     changed |= result;
   }
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
index 8a7fba6a48f..ea3101fa0ed 100644
--- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
@@ -198,32 +198,26 @@ ENTRY main {
   c1_1 = f32[40,40] parameter(3)
   p = pred[] parameter(4)
   t = (f32[20,40], f32[40,40], f32[20,40], f32[40,40]) tuple(c0_0, c0_1, c1_0, c1_1)
-  ROOT result = (f32[20, 40]) conditional(p,t,t), false_computation=on_false, true_computation=on_true
+  call = (f32[20,40]) call(t), to_apply=on_true
+  ROOT result = (f32[20,40]) conditional(p,t,t), false_computation=on_false, true_computation=on_true
 }
 )";
   auto status = ParseAndReturnVerifiedModule(hlo_string);
   TF_ASSERT_OK(status.status());
+  std::unique_ptr<VerifiedHloModule> module = status.ConsumeValueOrDie();
   HloVerifier v(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false);
-  TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status());
-  EXPECT_TRUE(
-      ConditionalSimplifier().Run(status.ValueOrDie().get()).ValueOrDie());
-  TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status());
-  EXPECT_EQ(status.ValueOrDie()
-                ->entry_computation()
-                ->root_instruction()
-                ->operand(1)
-                ->shape()
-                .tuple_shapes()
-                .size(),
-            2);
-  EXPECT_EQ(status.ValueOrDie()
-                ->entry_computation()
-                ->root_instruction()
-                ->operand(2)
-                ->shape()
-                .tuple_shapes()
-                .size(),
-            2);
+  TF_ASSERT_OK(v.Run(module.get()).status());
+  EXPECT_TRUE(ConditionalSimplifier().Run(module.get()).ValueOrDie());
+  TF_ASSERT_OK(v.Run(module.get()).status());
+  HloInstruction* conditional = module->entry_computation()->root_instruction();
+  EXPECT_TRUE(conditional != nullptr);
+  EXPECT_EQ(conditional->operand(1)->shape().tuple_shapes().size(), 2);
+  EXPECT_EQ(conditional->operand(2)->shape().tuple_shapes().size(), 2);
+  // For the call operation, nothing should have changed.
+  HloInstruction* call = FindInstruction(module.get(), "call");
+  EXPECT_EQ(
+      call->to_apply()->parameter_instruction(0)->shape().tuple_shapes().size(),
+      4);
 }

 TEST_F(ConditionalSimplifierTest,
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 121bdedf2dd..3460e65b0a2 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -118,6 +118,9 @@ cc_library(
         ":target_machine_features",
         "@com_google_absl//absl/base",
         "@com_google_absl//absl/types:span",
+        "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
+        "@llvm-project//mlir:ExecutionEngineUtils",
+        "@llvm-project//mlir:LLVMDialect",
         "//tensorflow/compiler/xla/service:copy_insertion",
         "//tensorflow/compiler/xla/service:hlo_casting_utils",
         "//tensorflow/compiler/xla/service:dump",
@@ -146,6 +149,7 @@ cc_library(
        "//tensorflow/compiler/xla/service:conditional_simplifier",
        "//tensorflow/compiler/xla/service:convolution_group_converter",
        "//tensorflow/compiler/xla/service:dot_decomposer",
+       "//tensorflow/compiler/xla/service:dynamic_padder",
        "//tensorflow/compiler/xla/service:dynamic_index_splitter",
        "//tensorflow/compiler/xla/service:executable",
        "//tensorflow/compiler/xla/service:flatten_call_graph",
@@ -365,6 +369,7 @@ cc_library(
         "@llvm-project//llvm:core",
         "@llvm-project//llvm:support",
         "@llvm-project//llvm:target",
+        "@llvm-project//mlir:IR",
     ],
 )
@@ -455,6 +460,7 @@ cc_library(
         ":cpu_options",
         ":cpu_runtime",
         ":ir_emission_utils",
+        ":mlir_emitter",
         ":target_machine_features",
         ":tiled_dot_emitter",
         ":vector_support_library",
@@ -473,6 +479,10 @@ cc_library(
         "//tensorflow/core:lib",
         "@com_google_absl//absl/strings",
         "@llvm-project//llvm:core",
+        "@llvm-project//mlir:EDSC",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:LinalgOps",
+        "@llvm-project//mlir:StandardOps",
     ],
 )
@@ -1069,3 +1079,24 @@ tf_cc_test(
         "@llvm-project//llvm:target",
     ],
 )
+
+cc_library(
+    name = "mlir_emitter",
+    srcs = ["mlir_emitter.cc"],
+    hdrs = ["mlir_emitter.h"],
+    deps = [
+        "//tensorflow/compiler/mlir/xla:hlo_utils",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status",
+        "@llvm-project//llvm:core",
+        "@llvm-project//llvm:ipo",
+        "@llvm-project//llvm:linker",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:LLVMTransforms",
+        "@llvm-project//mlir:LinalgToLLVM",
+        "@llvm-project//mlir:LinalgTransforms",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:TargetLLVMIR",
+        "@llvm-project//mlir:VectorToLLVM",
+    ],
+)
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index b04237138e8..b2416ac2799 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -42,6 +42,8 @@ limitations under the License.
#include "llvm/Support/TargetSelect.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetOptions.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/InitAllDialects.h" // from @llvm-project #include "tensorflow/compiler/xla/cpu_function_runtime.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" @@ -72,6 +74,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" +#include "tensorflow/compiler/xla/service/dynamic_padder.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -157,6 +160,8 @@ CpuCompiler::CpuCompiler() { // Initialize LLVM's MC layer for the native target. llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); + + mlir::registerAllDialects(); } namespace { @@ -239,7 +244,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( HloPassPipeline pipeline("HLO passes through layout assignment"); pipeline.AddInvariantChecker(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); - // Expand random number generation. pipeline.AddPass(); pipeline.AddPass(RandomAlgorithm::RNG_PHILOX); @@ -273,6 +277,13 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass( cost_model, /*convert_batch_groups_only=*/false); + pipeline.AddPass(); + pipeline.AddPass( + /*rewrite_training_op=*/true, + /*rewrite_inference_op=*/true, + /*rewrite_grad_op=*/true); + pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(target_machine_features); { auto& pass = @@ -281,12 +292,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( /*allow_mixed_precision=*/false); pass.AddPass(); - pass.AddPass(); - pass.AddPass( - /*rewrite_training_op=*/true, - /*rewrite_inference_op=*/true, - /*rewrite_grad_op=*/true); - pipeline.AddPass(); AlgebraicSimplifierOptions options; options.set_enable_dot_strength_reduction(false); pass.AddPass(options); @@ -605,9 +610,11 @@ StatusOr> CpuCompiler::RunBackend( user_post_optimization_hook_); // Compile must be thread-safe so create a new LLVM context for the module. - auto llvm_context = absl::make_unique(); - auto llvm_module = - absl::make_unique("__compute_module", *llvm_context); + mlir::MLIRContext mlir_context; + auto llvm_module = absl::make_unique( + "__compute_module", + mlir_context.getRegisteredDialect() + ->getLLVMContext()); auto jit = absl::make_unique( CompilerTargetOptions(module->config()), @@ -661,7 +668,7 @@ StatusOr> CpuCompiler::RunBackend( // before a caller computation. LLVMTargetMachineFeatures target_machine_features(jit->target_machine()); - IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), + IrEmitter ir_emitter(&mlir_context, *module, *assignment, llvm_module.get(), std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), &target_machine_features, @@ -815,8 +822,11 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, opt_level)); // Compile must be thread-safe so create a new LLVM context for the module. 
- llvm::LLVMContext llvm_context; - llvm::Module llvm_module("__compute_module", llvm_context); + mlir::MLIRContext mlir_context; + llvm::Module llvm_module( + "__compute_module", + mlir_context.getRegisteredDialect() + ->getLLVMContext()); llvm_module.setDataLayout(target_machine->createDataLayout()); llvm_module.setTargetTriple(triple.getTriple()); if (pic_level != llvm::PICLevel::NotPIC) { @@ -865,7 +875,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, } LLVMTargetMachineFeatures target_machine_features(target_machine.get()); - IrEmitter ir_emitter(*module, *assignment, &llvm_module, + IrEmitter ir_emitter(&mlir_context, *module, *assignment, &llvm_module, std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), &target_machine_features, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 8c1ae0179c0..f031daecb1f 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -363,7 +363,12 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream( if (shape.IsOpaque()) { return sizeof(void*); } - return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); + if (shape.is_static() || shape.IsTuple()) { + return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); + } + // Each dynamic dimension size is represented as a S32. + int64 metadata_size = sizeof(int32) * shape.dimensions_size(); + return ShapeUtil::ByteSizeOf(shape, sizeof(void*)) + metadata_size; } const InstructionValueSet& CpuExecutable::GetRootValueSet() const { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc index ff654c83d61..c0222010fd9 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc @@ -25,6 +25,7 @@ const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size"; const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor"; const char* const kXlaForceEnableExperimentalLlvmIrGemm = "xla_force_enable_experimental_llvm_ir_gemm"; +const char* const kXlaUseLinalgForDot = "xla_use_linalg_for_dot"; const char* const kLlvmIrGemmTileSize = "xla_llvm_ir_gemm_tile_size"; } // namespace @@ -63,6 +64,12 @@ bool ForceEnableExperimentalLlvmIrGemm(const HloModuleConfig& config) { return extra_options_map.count(kXlaForceEnableExperimentalLlvmIrGemm) > 0; } +bool UseLinalgForDot(const HloModuleConfig& config) { + const auto& extra_options_map = + config.debug_options().xla_backend_extra_options(); + return extra_options_map.count(kXlaUseLinalgForDot) > 0; +} + static absl::string_view RemoveSuffix(absl::string_view str, absl::string_view suffix) { CHECK_GE(str.size(), suffix.size()); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h index 99e6702d14a..5d25aef6912 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h @@ -27,6 +27,7 @@ namespace options { bool OptimizeForSizeRequested(const HloModuleConfig& config); bool VectorizedReduceDisabled(const HloModuleConfig& config); bool ForceEnableExperimentalLlvmIrGemm(const HloModuleConfig& config); +bool UseLinalgForDot(const HloModuleConfig& config); absl::optional LlvmIrGemvTilingFactor(const HloModuleConfig& config); absl::optional> LlvmIrGemmTileSize( const HloModuleConfig& config); diff --git 
a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index bd949aa24c7..7abf5da0b64 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -67,6 +67,10 @@ extern const char* const kEigenMatMulF32SymbolName = "__xla_cpu_runtime_EigenMatMulF32"; extern const char* const kEigenMatMulF64SymbolName = "__xla_cpu_runtime_EigenMatMulF64"; +extern const char* const kEigenMatMulC64SymbolName = + "__xla_cpu_runtime_EigenMatMulC64"; +extern const char* const kEigenMatMulC128SymbolName = + "__xla_cpu_runtime_EigenMatMulC128"; extern const char* const kEigenMatMulS32SymbolName = "__xla_cpu_runtime_EigenMatMulS32"; extern const char* const kMKLConvF32SymbolName = "__xla_cpu_runtime_MKLConvF32"; @@ -91,6 +95,10 @@ extern const char* const kEigenSingleThreadedMatMulF32SymbolName = "__xla_cpu_runtime_EigenSingleThreadedMatMulF32"; extern const char* const kEigenSingleThreadedMatMulF64SymbolName = "__xla_cpu_runtime_EigenSingleThreadedMatMulF64"; +extern const char* const kEigenSingleThreadedMatMulC64SymbolName = + "__xla_cpu_runtime_EigenSingleThreadedMatMulC64"; +extern const char* const kEigenSingleThreadedMatMulC128SymbolName = + "__xla_cpu_runtime_EigenSingleThreadedMatMulC128"; extern const char* const kEigenSingleThreadedMatMulS32SymbolName = "__xla_cpu_runtime_EigenSingleThreadedMatMulS32"; extern const char* const kEigenSingleThreadedConvF16SymbolName = diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index 14ea5448eef..492ce3f68b2 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -46,6 +46,8 @@ namespace runtime { extern const char* const kEigenMatMulF16SymbolName; extern const char* const kEigenMatMulF32SymbolName; extern const char* const kEigenMatMulF64SymbolName; +extern const char* const kEigenMatMulC64SymbolName; +extern const char* const kEigenMatMulC128SymbolName; extern const char* const kEigenMatMulS32SymbolName; extern const char* const kMKLConvF32SymbolName; extern const char* const kMKLMatMulF32SymbolName; @@ -59,6 +61,8 @@ extern const char* const kEigenSingleThreadedFftSymbolName; extern const char* const kEigenSingleThreadedMatMulF16SymbolName; extern const char* const kEigenSingleThreadedMatMulF32SymbolName; extern const char* const kEigenSingleThreadedMatMulF64SymbolName; +extern const char* const kEigenSingleThreadedMatMulC64SymbolName; +extern const char* const kEigenSingleThreadedMatMulC128SymbolName; extern const char* const kEigenSingleThreadedMatMulS32SymbolName; extern const char* const kEigenSingleThreadedConvF16SymbolName; extern const char* const kEigenSingleThreadedConvF32SymbolName; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index fae9670051a..e21ed7ad60e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -154,7 +154,8 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor, int64 size, const void* source) { if (size > std::numeric_limits::max()) { - return InvalidArgument("Infeed shape is too large: needs %d bytes", size); + return InvalidArgument("CPU infeed of %d bytes exceeds maximum of %d bytes", + size, std::numeric_limits::max()); } if (size <= 0) { diff --git 
a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 7dba826b65c..9e75c1b9ac5 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -23,8 +23,17 @@ limitations under the License. #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" +#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" // from @llvm-project +#include "mlir/EDSC/Builders.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "tensorflow/compiler/xla/service/cpu/cpu_options.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h" #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" @@ -89,6 +98,9 @@ enum class DotImplementationStrategy { // and the output have to be row major. kTiledLlvmIrGemm, + // The dot operation is lowered into linalg.matmul op and lowered to LLVM IR. + kLinalgMatmul, + // The dot operation is lowered into a call into an Eigen routine. No fusions // are supported today. The two inputs and the output have to be row major. // However, we do allow transposing either the LHS or the RHS as part of the @@ -112,7 +124,7 @@ class DotOpEmitter { const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, - llvm::IRBuilder<>* b, + llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features); @@ -163,6 +175,9 @@ class DotOpEmitter { // Lowers the dot operation as a tiled Matrix*Matrix loop. void EmitTiledLlvmIrGemm(); + // Lowers the dot operation through MLIR's linalg.matmul. + Status EmitLinalgMatmul(); + // Lowers the dot operation as a naive nested loop that computes the result // one element at a time. 
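// Illustrative sketch, not part of the patch: the new kLinalgMatmul strategy is
// opt-in through the "xla_use_linalg_for_dot" backend extra option checked by
// cpu_options::UseLinalgForDot above. One way a caller could request it,
// assuming the usual DebugOptions plumbing; MakeLinalgDotDebugOptions is a
// hypothetical helper:
#include "tensorflow/compiler/xla/xla.pb.h"

xla::DebugOptions MakeLinalgDotDebugOptions() {
  xla::DebugOptions debug_options;
  // xla_backend_extra_options is a map<string, string>; UseLinalgForDot only
  // checks for the key's presence, so the value can stay empty.
  (*debug_options.mutable_xla_backend_extra_options())
      ["xla_use_linalg_for_dot"] = "";
  return debug_options;
}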
void EmitNaiveLlvmIrGemm(); @@ -194,20 +209,19 @@ class DotOpEmitter { const llvm_ir::IrArray* addend_array_; llvm::Value* executable_run_options_value_; llvm::IRBuilder<>* b_; + mlir::MLIRContext* mlir_context_; const HloModuleConfig& hlo_module_config_; const TargetMachineFeatures& target_machine_features_; }; } // namespace -DotOpEmitter::DotOpEmitter(DotInfo dot_info, string dot_hlo_name, - const llvm_ir::IrArray& target_array, - const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, - const llvm_ir::IrArray* addend_array, - llvm::Value* executable_run_options_value, - llvm::IRBuilder<>* b, - const HloModuleConfig& hlo_module_config, - const TargetMachineFeatures& target_machine_features) +DotOpEmitter::DotOpEmitter( + DotInfo dot_info, string dot_hlo_name, const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, + llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, + mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features) : dot_info_(std::move(dot_info)), dot_hlo_name_(std::move(dot_hlo_name)), target_array_(target_array), @@ -216,9 +230,36 @@ DotOpEmitter::DotOpEmitter(DotInfo dot_info, string dot_hlo_name, addend_array_(addend_array), executable_run_options_value_(executable_run_options_value), b_(b), + mlir_context_(mlir_context), hlo_module_config_(hlo_module_config), target_machine_features_(target_machine_features) {} +Status DotOpEmitter::EmitLinalgMatmul() { + Shape operand_shapes[] = {dot_info_.lhs_shape, dot_info_.rhs_shape}; + llvm::Value* operand_ptrs[] = {lhs_array_.GetBasePointer(), + rhs_array_.GetBasePointer()}; + llvm::Value* target_ptr = target_array_.GetBasePointer(); + + // Zero out the output buffer. 
+ int64 size_bytes = ShapeUtil::ByteSizeOf(dot_info_.result_shape); + b_->CreateMemSet(target_ptr, b_->getInt8(0), /*Size=*/size_bytes, + /*Align=*/llvm::MaybeAlign(1)); + + std::string name = + absl::StrCat("linalgMatMul_", dot_info_.result_shape.ToString(true), "_", + dot_info_.lhs_shape.ToString(true), "_", + dot_info_.rhs_shape.ToString(true)); + return EmitMlirFuncAndCall( + mlir_context_, b_, dot_info_.result_shape, operand_shapes, target_ptr, + operand_ptrs, name, [&](mlir::OpBuilder* builder, mlir::FuncOp function) { + mlir::edsc::ScopedContext scope(*builder, function.getLoc()); + mlir::Value a = function.getArgument(0), b = function.getArgument(1), + c = function.getArgument(2); + mlir::edsc::intrinsics::linalg_matmul(b, c, a); + mlir::edsc::intrinsics::std_ret(); + }); +} + void DotOpEmitter::EmitTiledLlvmIrGemm() { PrimitiveType primitive_type = dot_info_.result_shape.element_type(); MatMultDims mat_mult_dims = GetMatMultDims(); @@ -418,6 +459,9 @@ Status DotOpEmitter::Emit() { EmitTiledLlvmIrGemm(); return Status::OK(); + case DotImplementationStrategy::kLinalgMatmul: + return EmitLinalgMatmul(); + case DotImplementationStrategy::kEigen: return EmitCallToRuntime(); } @@ -613,6 +657,8 @@ Status DotOpEmitter::EmitCallToRuntime() { bool multi_threaded = ShouldUseMultiThreadedEigen(hlo_module_config_); bool use_mkl_dnn = hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn(); PrimitiveType type = target_array_.GetShape().element_type(); + llvm::Function* function = b_->GetInsertBlock()->getParent(); + llvm::Module* module = function->getParent(); llvm::Type* float_type; const char* fn_name; switch (type) { @@ -640,6 +686,18 @@ Status DotOpEmitter::EmitCallToRuntime() { : runtime::kEigenSingleThreadedMatMulF64SymbolName); float_type = b_->getDoubleTy(); break; + case C64: + fn_name = multi_threaded + ? runtime::kEigenMatMulC64SymbolName + : runtime::kEigenSingleThreadedMatMulC64SymbolName; + float_type = llvm_ir::PrimitiveTypeToIrType(C64, module); + break; + case C128: + fn_name = multi_threaded + ? runtime::kEigenMatMulC128SymbolName + : runtime::kEigenSingleThreadedMatMulC128SymbolName; + float_type = llvm_ir::PrimitiveTypeToIrType(C128, module); + break; case S32: fn_name = multi_threaded ? runtime::kEigenMatMulS32SymbolName @@ -661,9 +719,6 @@ Status DotOpEmitter::EmitCallToRuntime() { int64_type, int64_type, int64_type, int32_type, int32_type}, /*isVarArg=*/false); - llvm::Function* function = b_->GetInsertBlock()->getParent(); - llvm::Module* module = function->getParent(); - llvm::FunctionCallee matmul_func = module->getOrInsertFunction(fn_name, matmul_type); if (auto* fn = llvm::dyn_cast(matmul_func.getCallee())) { @@ -809,9 +864,11 @@ bool AreGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, << output_shape.DebugString(); switch (output_shape.element_type()) { - case F64: - case F32: case F16: + case F32: + case F64: + case C64: + case C128: case S32: return IsRank2(lhs_shape) && IsRank2(rhs_shape) && IsRank2(output_shape); default: @@ -860,7 +917,9 @@ bool CanEmitTiledLlvmIrGemm( return false; } - if (dot_info.result_shape.element_type() == F16) { + if (dot_info.result_shape.element_type() == F16 || + dot_info.result_shape.element_type() == C64 || + dot_info.result_shape.element_type() == C128) { // TODO(sanjoy): This is probably easy to fix, but I want to keep the CL // adding this comment NFC. 
return false; @@ -886,9 +945,12 @@ DotImplementationStrategy GetDotImplementationStrategy( } if (IsAlignedGemm(dot_info, target_machine_features)) { - return CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features) - ? DotImplementationStrategy::kTiledLlvmIrGemm - : DotImplementationStrategy::kEigen; + if (CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features)) { + return options::UseLinalgForDot(config) + ? DotImplementationStrategy::kLinalgMatmul + : DotImplementationStrategy::kTiledLlvmIrGemm; + } + return DotImplementationStrategy::kEigen; } return DotImplementationStrategy::kNaiveLlvmIr; @@ -899,15 +961,15 @@ Status EmitNonBatchDotOperation( const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, - const HloModuleConfig& hlo_module_config, + mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features) { PrimitiveType type = target_array.GetShape().element_type(); TF_RET_CHECK(S32 == type || F16 == type || F32 == type || F64 == type || C64 == type || C128 == type); DotOpEmitter dot_emitter(std::move(dot_info), std::move(hlo_name), target_array, lhs_array, rhs_array, addend_array, - executable_run_options_value, b, hlo_module_config, - target_machine_features); + executable_run_options_value, b, mlir_context, + hlo_module_config, target_machine_features); return dot_emitter.Emit(); } @@ -981,7 +1043,7 @@ Status EmitBatchDotOperation( const HloInstruction& dot, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, - const HloModuleConfig& hlo_module_config, + mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features) { TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(dot.dot_dimension_numbers())); @@ -1039,7 +1101,7 @@ Status EmitBatchDotOperation( // Emit the inner non-batch dot operation. 
return EmitNonBatchDotOperation( dot_info, dot.name(), target_slice, lhs_slice, rhs_slice, nullptr, - executable_run_options_value, b, hlo_module_config, + executable_run_options_value, b, mlir_context, hlo_module_config, target_machine_features); }); } @@ -1089,7 +1151,7 @@ Status EmitDotOperation(const HloInstruction& dot, const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, - llvm::IRBuilder<>* b, + llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features) { // This routine assumes that the dot operation is not in a parallelized @@ -1099,13 +1161,13 @@ Status EmitDotOperation(const HloInstruction& dot, if (IsBatchDot(dot)) { TF_RET_CHECK(addend_array == nullptr); return EmitBatchDotOperation(dot, target_array, lhs_array, rhs_array, - executable_run_options_value, b, + executable_run_options_value, b, mlir_context, hlo_module_config, target_machine_features); } return EmitNonBatchDotOperation(DotInfo(dot), dot.name(), target_array, lhs_array, rhs_array, addend_array, - executable_run_options_value, b, + executable_run_options_value, b, mlir_context, hlo_module_config, target_machine_features); } } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index 105bd3005c8..d9cf8a2036b 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -18,6 +18,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "llvm/IR/IRBuilder.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -63,7 +64,7 @@ Status EmitDotOperation(const HloInstruction& dot, const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, - llvm::IRBuilder<>* b, + llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features); } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index e21ca01c803..05364a4492b 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -109,24 +109,6 @@ llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator) { switch (hlo->opcode()) { - case HloOpcode::kMap: - return [this, hlo, &operand_to_generator]( - const IrArray::Index& index) -> StatusOr { - std::vector operands; - for (int i = 0; i < hlo->operand_count(); i++) { - TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, - operand_to_generator.at(hlo->operand(i))(index)); - operands.push_back(operand_value); - } - return ir_emitter_->EmitElementalMap(*Cast(hlo), - operands, llvm_ir::IrName(hlo)); - }; - case HloOpcode::kReduceWindow: - return [this, hlo, &operand_to_generator](const IrArray::Index& index) { - return ir_emitter_->EmitElementalReduceWindow( - Cast(hlo), - operand_to_generator.at(hlo->operand(0)), index); - }; case HloOpcode::kConvolution: return [this, hlo, &operand_to_generator](const 
IrArray::Index& index) { return ir_emitter_->EmitElementalConvolution( @@ -134,22 +116,6 @@ llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( operand_to_generator.at(hlo->operand(0)), operand_to_generator.at(hlo->operand(1)), index); }; - case HloOpcode::kReduce: - return [this, hlo, &operand_to_generator](const IrArray::Index& index) { - auto reduce_instr = Cast(hlo); - std::vector input_generators; - for (const HloInstruction* instr : reduce_instr->inputs()) { - input_generators.push_back(operand_to_generator.at(instr)); - } - - std::vector initial_value_generators; - for (const HloInstruction* instr : reduce_instr->init_values()) { - initial_value_generators.push_back(operand_to_generator.at(instr)); - } - return ir_emitter_->EmitElementalReduce( - reduce_instr, std::move(input_generators), - std::move(initial_value_generators), index); - }; default: return ElementalIrEmitter::MakeElementGenerator(hlo, operand_to_generator); diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h index e3fba9306b7..5c9f6677ab3 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h @@ -44,6 +44,12 @@ class CpuElementalIrEmitter : public ElementalIrEmitter { StatusOr EmitTanh(PrimitiveType prim_type, llvm::Value* value) override; + StatusOr> EmitThreadLocalCall( + const HloComputation& callee, absl::Span parameters, + absl::string_view name) override { + return ir_emitter_->EmitThreadLocalCall(callee, parameters, name); + } + IrEmitter* ir_emitter_; }; diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index c19fa779b60..998b9db132c 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include + #include #include #include @@ -40,6 +41,7 @@ limitations under the License. 
#include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicsX86.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -88,8 +90,8 @@ using llvm_ir::SetToFirstInsertPoint; namespace cpu { IrEmitter::IrEmitter( - const HloModule& hlo_module, const BufferAssignment& assignment, - llvm::Module* llvm_module, + mlir::MLIRContext* mlir_context, const HloModule& hlo_module, + const BufferAssignment& assignment, llvm::Module* llvm_module, std::unordered_map instruction_to_profile_idx, std::unordered_map computation_to_profile_idx, const TargetMachineFeatures* target_machine_features, @@ -98,6 +100,7 @@ IrEmitter::IrEmitter( module_(llvm_module), arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()), b_(llvm_module->getContext()), + mlir_context_(mlir_context), instruction_to_profile_idx_(std::move(instruction_to_profile_idx)), computation_to_profile_idx_(std::move(computation_to_profile_idx)), alias_analysis_(hlo_module, assignment, &llvm_module->getContext()), @@ -570,25 +573,9 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort)); Shape keys_shape = sort->keys()->shape(); PrimitiveType keys_type = keys_shape.element_type(); - switch (keys_type) { - case PRED: - case S8: - case U8: - case S16: - case U16: - case BF16: - case F16: - case S32: - case U32: - case F32: - case S64: - case U64: - case F64: - break; - default: - return Unimplemented( - "Element type %s not supported in the Sort op on CPU.", - PrimitiveType_Name(keys_type)); + if (!primitive_util::IsArrayType(keys_type)) { + return Unimplemented("Element type %s not supported in the Sort op on CPU.", + PrimitiveType_Name(keys_type)); } std::vector destination_addresses(sort->operand_count()); for (int64 i = 0; i < sort->operand_count(); ++i) { @@ -695,101 +682,6 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) { return Status::OK(); } -llvm::Value* IrEmitter::EmitElementalMap( - const HloMapInstruction& map_instr, - absl::Span elemental_operands, absl::string_view name) { - return EmitScalarReturningThreadLocalCall(*map_instr.to_apply(), - elemental_operands, name); -} - -StatusOr IrEmitter::EmitElementalReduceWindow( - const HloReduceWindowInstruction* reduce_window, - const llvm_ir::ElementGenerator& input_generator, - const llvm_ir::IrArray::Index& index) { - const HloInstruction* operand = reduce_window->operand(0); - const Window& window = reduce_window->window(); - - // We fold inputs into the accumulator and initialize it to - // the initial value on the reduce_window. 
- PrimitiveType operand_element_type = operand->shape().element_type(); - llvm::Value* accumulator_address = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), - "reduce_window_accumulator_address", &b_, - MinimumAlignmentForPrimitiveType(operand_element_type)); - Store(Load(GetEmittedValueFor(reduce_window->operand(1))), - accumulator_address); - - llvm_ir::ForLoopNest loops(IrName(reduce_window, "inner"), &b_); - std::vector window_size; - for (const auto& dim : window.dimensions()) { - window_size.push_back(dim.size()); - } - const llvm_ir::IrArray::Index window_index = loops.AddLoopsForShape( - ShapeUtil::MakeShape(operand_element_type, window_size), "window"); - CHECK_EQ(window_index.size(), index.size()); - - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); - - std::vector input_multi_index(index.size()); - llvm::Value* in_bounds_condition = nullptr; - for (size_t i = 0; i < index.size(); ++i) { - llvm::Value* strided_index = - NSWMul(index[i], b_.getInt64(window.dimensions(i).stride())); - input_multi_index[i] = NSWSub( - NSWAdd(strided_index, - NSWMul(window_index[i], - b_.getInt64(window.dimensions(i).window_dilation()))), - b_.getInt64(window.dimensions(i).padding_low())); - - // We need to verify that we are not in the dilated base area. - llvm::Value* dilation_condition = - ICmpEQ(SRem(input_multi_index[i], - b_.getInt64(window.dimensions(i).base_dilation())), - b_.getInt64(0)); - if (in_bounds_condition == nullptr) { - in_bounds_condition = dilation_condition; - } else { - in_bounds_condition = And(in_bounds_condition, dilation_condition); - } - - // Apply base dilation to the index. - input_multi_index[i] = - SDiv(input_multi_index[i], - b_.getInt64(window.dimensions(i).base_dilation())); - - // We need to check if 0 <= input_multi_index[i] < bound, as otherwise we - // are in the padding so that we can skip the computation. That is - // equivalent to input_multi_index[i] < bound as an *unsigned* comparison, - // since a negative value will wrap to a large positive value. - llvm::Value* index_condition = - ICmpULT(input_multi_index[i], - b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); - if (in_bounds_condition == nullptr) { - in_bounds_condition = index_condition; - } else { - in_bounds_condition = And(in_bounds_condition, index_condition); - } - } - CHECK(in_bounds_condition != nullptr); - - llvm_ir::LlvmIfData if_data = - llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_); - SetToFirstInsertPoint(if_data.true_block, &b_); - - // We are not in the padding, so carry out the computation. - llvm_ir::IrArray::Index input_index(input_multi_index, operand->shape(), - b_.getInt64Ty()); - TF_ASSIGN_OR_RETURN(llvm::Value* const input_value, - input_generator(input_index)); - llvm::Value* result = EmitScalarReturningThreadLocalCall( - *reduce_window->to_apply(), {Load(accumulator_address), input_value}, - "reducer_function"); - Store(result, accumulator_address); - - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); - return Load(accumulator_address); -} - Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { // Pseudo code for reduce window: // @@ -1008,7 +900,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // Dot operation is complicated so we delegate to a helper class. 
return EmitDotOperation(*dot, target_array, lhs_array, rhs_array, /*addend_array=*/nullptr, - GetExecutableRunOptionsArgument(), &b_, + GetExecutableRunOptionsArgument(), &b_, mlir_context_, hlo_module_config_, target_machine_features_); } @@ -1325,7 +1217,7 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { auto operand = fft->operand(0); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*fft, /*operands=*/{operand}, - /*supported_types=*/{F32, C64})); + /*supported_types=*/{F32, F64, C64, C128})); TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout())); TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout())); VLOG(3) << "operand=" << ShapeUtil::HumanStringWithLayout(operand->shape()); @@ -1347,7 +1239,7 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { llvm::FunctionType* fft_type = llvm::FunctionType::get( b_.getVoidTy(), {int8_ptr_type, int8_ptr_type, int8_ptr_type, int32_type, int32_type, - int64_type, int64_type, int64_type, int64_type}, + int32_type, int64_type, int64_type, int64_type, int64_type}, /*isVarArg=*/false); bool multi_threaded_eigen = @@ -1366,6 +1258,8 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { {GetExecutableRunOptionsArgument(), BitCast(GetEmittedValueFor(fft), int8_ptr_type), BitCast(operand_address, int8_ptr_type), b_.getInt32(fft->fft_type()), + b_.getInt32(operand->shape().element_type() == F64 || + operand->shape().element_type() == C128), b_.getInt32(fft_rank), b_.getInt64(input_batch), b_.getInt64(fft_rank > 0 ? fft_length[0] : 0), b_.getInt64(fft_rank > 1 ? fft_length[1] : 0), @@ -2099,108 +1993,6 @@ StatusOr IrEmitter::EmitVectorizedReduce( return true; } -StatusOr IrEmitter::EmitElementalReduce( - const HloReduceInstruction* reduce, - std::vector input_generators, - std::vector initial_value_generators, - const llvm_ir::IrArray::Index& index) { - const Shape& out_shape = reduce->shape(); - bool is_variadic = !out_shape.IsArray(); - int accumulators_count = 1; - if (is_variadic) { - CHECK(out_shape.IsTuple()); - accumulators_count = out_shape.tuple_shapes_size(); - } - - absl::Span reduced_dimensions(reduce->dimensions()); - - std::vector accumulator_addrs; - std::vector accumulator_types; - for (int i = 0; i < accumulators_count; i++) { - const Shape& element_shape = - is_variadic ? out_shape.tuple_shapes(i) : out_shape; - PrimitiveType accumulator_type = element_shape.element_type(); - llvm::Type* accumulator_llvm_type = - llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_); - accumulator_types.push_back(accumulator_llvm_type); - - // Initialize an accumulator with init_value. - llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry( - accumulator_llvm_type, "accumulator_" + std::to_string(i), &b_, - MinimumAlignmentForPrimitiveType(accumulator_type)); - TF_ASSIGN_OR_RETURN( - llvm::Value* const init_value, - initial_value_generators[i](llvm_ir::IrArray::Index(index.GetType()))); - Store(init_value, accumulator_addr); - accumulator_addrs.push_back(accumulator_addr); - } - - // The enclosing loops go over all the target elements. Now we have to compute - // the actual target element. For this, we build a new loop nest to iterate - // over all the reduction dimensions in the argument. - // AddLoopsForShapeOnDimensions will return an Index where induction Value*s - // are placed for each dimension in dimensions, and all the rest are nullptrs. 
- llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_); - const HloInstruction* arg = reduce->operand(0); - std::vector input_multi_index = - loops.AddLoopsForShapeOnDimensions(arg->shape(), reduced_dimensions, - "reduction_dim"); - - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); - - // Build a full index for the input argument, using input_multi_index as the - // base. In input_multi_index only the reduction dimensions are filled in. We - // fill in the rest of the dimensions with induction Value*s taken from - // 'index' which iterates over the target array. See the high-level - // description in the XLA documentation for details. - llvm_ir::IrArray::Index::const_iterator it = index.begin(); - - for (auto& i : input_multi_index) { - if (i == nullptr) { - i = *it++; - } - } - CHECK(index.end() == it); - llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(), - b_.getInt64Ty()); - - std::vector reduction_operands; - for (llvm::Value* accum : accumulator_addrs) { - llvm::Value* accum_value = Load(accum); - reduction_operands.push_back(accum_value); - } - - for (int i = 0; i < accumulators_count; i++) { - TF_ASSIGN_OR_RETURN(llvm::Value* const input_element, - input_generators[i](input_index)); - reduction_operands.push_back(input_element); - } - - std::vector results = EmitThreadLocalCall( - *reduce->to_apply(), reduction_operands, "reduce_function"); - - CHECK(results.size() == accumulators_count); - for (int i = 0; i < accumulators_count; i++) { - Store(results[i], accumulator_addrs[i]); - } - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); - - if (is_variadic) { - // Emit a structure, as that what the LoopEmitter expects. - llvm::Value* returned_structure = llvm::UndefValue::get( - llvm::StructType::get(b_.getContext(), accumulator_types)); - for (int i = 0; i < accumulators_count; i++) { - llvm::Value* accumulator_value = Load(accumulator_addrs[i]); - returned_structure = - b_.CreateInsertValue(returned_structure, accumulator_value, i); - } - return returned_structure; - } else { - CHECK_EQ(accumulator_addrs.size(), 1); - return Load(accumulator_addrs[0]); - } -} - Status IrEmitter::HandleReduce(HloInstruction* reduce) { auto arg = reduce->mutable_operand(0); auto init_value = reduce->mutable_operand(1); @@ -2517,10 +2309,10 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { llvm_ir::IrArray addend_array( GetIrArrayFor(fusion->operand(addend_param_number))); - TF_RETURN_IF_ERROR( - EmitDotOperation(*dot, target_array, lhs_array, rhs_array, - &addend_array, GetExecutableRunOptionsArgument(), &b_, - hlo_module_config_, target_machine_features_)); + TF_RETURN_IF_ERROR(EmitDotOperation( + *dot, target_array, lhs_array, rhs_array, &addend_array, + GetExecutableRunOptionsArgument(), &b_, mlir_context_, + hlo_module_config_, target_machine_features_)); return Status::OK(); } else { return Unimplemented("Fusion kind not implemented on CPU"); @@ -2554,7 +2346,125 @@ Status IrEmitter::HandleCall(HloInstruction* call) { return Status::OK(); } +Status IrEmitter::HandleSliceToDynamic(HloInstruction* hlo) { + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo)); + std::vector dynamic_dims; + int32 raw_data_size = + ShapeUtil::ByteSizeOf(ShapeUtil::MakeStaticShape(hlo->shape())); + llvm::Value* dest_buffer = GetEmittedValueFor(hlo); + llvm::Value* raw_buffer = + b_.CreateBitCast(dest_buffer, b_.getInt8Ty()->getPointerTo()); + for (int64 i = 1; i < hlo->operand_count(); ++i) { + const int64 dim_index = i - 1; + llvm::Value* source_buffer = 
GetEmittedValueFor(hlo->operand(i)); + llvm::LoadInst* dyn_dim_size = b_.CreateLoad(source_buffer, "dyn_dim_size"); + + llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32( + b_.getInt8Ty(), raw_buffer, raw_data_size + dim_index * sizeof(int32)); + b_.CreateStore(dyn_dim_size, + b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo())); + dynamic_dims.push_back(b_.CreateIntCast(dyn_dim_size, b_.getInt64Ty(), + /*isSigned=*/true, + "i64_dyn_dim_size")); + } + + llvm_ir::IrArray data_array = GetIrArrayFor(hlo); + // Pseudo code for sliceToDynamic: + // + // for (index i in dynamic_dim) + // dest_index = delinearize(linearize(i, dynamic_dim), static_dim) + // dest[dest_index] = source[i] + auto loop_body_emitter = + [&](const llvm_ir::IrArray::Index& array_index) -> Status { + llvm::Value* source_element = + GetIrArrayFor(hlo->operand(0)).EmitReadArrayElement(array_index, &b_); + llvm::Value* linear_index = array_index.Linearize(dynamic_dims, &b_); + // Delinearize the index based on the static shape. + llvm_ir::IrArray::Index dest_index(linear_index, data_array.GetShape(), + &b_); + data_array.EmitWriteArrayElement(dest_index, source_element, &b_); + return Status::OK(); + }; + return llvm_ir::LoopEmitter(loop_body_emitter, data_array.GetShape(), + dynamic_dims, &b_) + .EmitLoop(IrName(hlo)); +} + +Status IrEmitter::HandlePadToStatic(HloInstruction* hlo) { + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo)); + + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice, + assignment_.GetUniqueSlice(hlo, {0})); + std::vector dynamic_dims; + std::vector tuple_operand_ptrs; + const Shape& data_shape = ShapeUtil::GetSubshape(hlo->shape(), {0}); + const Shape& input_shape = hlo->operand(0)->shape(); + llvm::Value* data_address = EmitBufferPointer(data_slice, data_shape); + llvm_ir::IrArray data_array(data_address, data_shape); + llvm::Value* source_buffer = GetEmittedValueFor(hlo->operand(0)); + llvm::Value* raw_buffer = + b_.CreateBitCast(source_buffer, b_.getInt8Ty()->getPointerTo()); + int64 raw_data_size = + ShapeUtil::ByteSizeOf(ShapeUtil::MakeStaticShape(input_shape)); + + // Put a placeholder for the data array's pointer + tuple_operand_ptrs.push_back(data_array.GetBasePointer()); + // PadToStatic has a dynamic tensor as input and variadic size of outputs: + // (static_tensor, dynamic_dim_0, dynamic_dim_1, ... ) + // Dynamic dimension sizes starts from output index 1. + for (int64 i = 1; i < hlo->shape().tuple_shapes_size(); ++i) { + // Read from the metadata section of the dynamic input (operand 0). 
+ const Shape& dim_shape = ShapeUtil::GetSubshape(hlo->shape(), {i}); + TF_RET_CHECK(Shape::Equal()(dim_shape, ShapeUtil::MakeScalarShape(S32))); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dim_size_slice, + assignment_.GetUniqueSlice(hlo, {i})); + llvm::Value* dest_dim_size_address = + EmitBufferPointer(dim_size_slice, data_shape); + const int64 dim_index = i - 1; + llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32( + b_.getInt8Ty(), raw_buffer, raw_data_size + dim_index * sizeof(int32)); + llvm::Value* dyn_dim_size = b_.CreateLoad( + b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()), + "dyn_dim_size"); + b_.CreateStore(dyn_dim_size, + b_.CreateBitCast(dest_dim_size_address, + b_.getInt32Ty()->getPointerTo())); + dynamic_dims.push_back(b_.CreateIntCast(dyn_dim_size, b_.getInt64Ty(), + /*isSigned=*/true, + "i64_dyn_dim_size")); + tuple_operand_ptrs.push_back(dest_dim_size_address); + } + + // Pseudo code for padToStatic: + // + // for (index i in dynamic_dim) + // source_index = delinearize(inearize(i, dynamic_dim), static_dim) + // dest[i] = source[source_index] + auto loop_body_emitter = + [&](const llvm_ir::IrArray::Index& array_index) -> Status { + llvm::Value* linear_index = array_index.Linearize(dynamic_dims, &b_); + llvm_ir::IrArray::Index source_index(linear_index, input_shape, &b_); + llvm::Value* source_element = + GetIrArrayFor(hlo->operand(0)).EmitReadArrayElement(source_index, &b_); + data_array.EmitWriteArrayElement(array_index, source_element, &b_); + return Status::OK(); + }; + TF_RETURN_IF_ERROR( + llvm_ir::LoopEmitter(loop_body_emitter, input_shape, dynamic_dims, &b_) + .EmitLoop(IrName(hlo))); + + // Emit static tensor and dynamic sizes as one tuple. + llvm_ir::EmitTuple(GetIrArrayFor(hlo), tuple_operand_ptrs, &b_); + return Status::OK(); +} + Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { + if (custom_call->custom_call_target() == "PadToStatic") { + return HandlePadToStatic(custom_call); + } + if (custom_call->custom_call_target() == "SliceToDynamic") { + return HandleSliceToDynamic(custom_call); + } absl::Span operands(custom_call->operands()); llvm::Type* i8_ptr_type = b_.getInt8PtrTy(); llvm::AllocaInst* operands_alloca = @@ -2999,9 +2909,8 @@ Status IrEmitter::HandleRngGetAndUpdateState(HloInstruction* rng_state) { old_state->getType()->getScalarType(), address->getType()->getPointerAddressSpace())); llvm::StoreInst* store = Store(old_state, address); - store->setAlignment( - llvm::MaybeAlign(IrEmitter::MinimumAlignmentForPrimitiveType( - rng_state->shape().element_type()))); + store->setAlignment(llvm::Align(IrEmitter::MinimumAlignmentForPrimitiveType( + rng_state->shape().element_type()))); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index cc5aa3f37fc..661785153d0 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_ #include + #include #include #include @@ -32,6 +33,7 @@ limitations under the License. 
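// Illustrative sketch, not part of the patch: the HandleSliceToDynamic and
// HandlePadToStatic handlers added above copy between a dynamically sized view
// and its statically bounded buffer by linearizing an index in one set of
// bounds and delinearizing it in the other, as their pseudo code comments
// describe. The index arithmetic in plain scalar form, assuming the default
// row-major layout; the helper names below are made up for this sketch:
#include <cstdint>
#include <vector>

// Row-major linearization of a multi-dimensional index over `bounds`.
int64_t LinearizeIndex(const std::vector<int64_t>& index,
                       const std::vector<int64_t>& bounds) {
  int64_t linear = 0;
  for (int d = 0; d < static_cast<int>(bounds.size()); ++d) {
    linear = linear * bounds[d] + index[d];
  }
  return linear;
}

// Inverse of the above, possibly against a different (e.g. static) set of bounds.
std::vector<int64_t> DelinearizeIndex(int64_t linear,
                                      const std::vector<int64_t>& bounds) {
  std::vector<int64_t> index(bounds.size());
  for (int d = static_cast<int>(bounds.size()) - 1; d >= 0; --d) {
    index[d] = linear % bounds[d];
    linear /= bounds[d];
  }
  return index;
}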
#include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "llvm/Target/TargetMachine.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/cpu/ir_function.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" @@ -58,6 +60,8 @@ namespace cpu { // functions. class IrEmitter : public DfsHloVisitorWithDefault, public IrBuilderMixin { + friend class CpuElementalIrEmitter; + public: using GeneratorForOperandIrArrays = std::function()>; @@ -67,14 +71,16 @@ class IrEmitter : public DfsHloVisitorWithDefault, // hlo_module: the HLO module we are emitting IR for. // assignment: a BufferAssignment from which we know which buffers are used by // the HLO nodes. - // llvm_module: the LLVM module to emit IR into. + // mlir_context: the MLIR context used for IR emission. + // llvm_module: the LLVM module to emit IR into. It's built using the LLVM + // context inside of mlir_context. // instruction_to_profile_idx: the mapping from HLO instructions to their // index in the profiling array. // computation_to_profile_idx: the mapping from HLO computations to their // index in the profiling array. // emit_code_for_msan: whether emitted code should be compatible with msan. - IrEmitter(const HloModule& hlo_module, const BufferAssignment& assignment, - llvm::Module* llvm_module, + IrEmitter(mlir::MLIRContext* mlir_context, const HloModule& hlo_module, + const BufferAssignment& assignment, llvm::Module* llvm_module, std::unordered_map instruction_to_profile_idx, std::unordered_map @@ -113,28 +119,12 @@ class IrEmitter : public DfsHloVisitorWithDefault, // Emit an LLVM global variable for every constant buffer allocation. Status EmitConstantGlobals(); - // Emit code to map one element according to `map_instr`. - llvm::Value* EmitElementalMap( - const HloMapInstruction& map_instr, - absl::Span elemental_operands, - absl::string_view name); - // Emit code to emit the element at `index` for a reduce window instruction. - StatusOr EmitElementalReduceWindow( - const HloReduceWindowInstruction* reduce_window, - const llvm_ir::ElementGenerator& input_generator, - const llvm_ir::IrArray::Index& index); // Emit code to emit the element at `index` for a convolution instruction. StatusOr EmitElementalConvolution( const HloConvolutionInstruction* convolution, const llvm_ir::ElementGenerator& input_generator, const llvm_ir::ElementGenerator& kernel_generator, const llvm_ir::IrArray::Index& index); - // Emit code to emit the element at `index` for a reduce instruction. - StatusOr EmitElementalReduce( - const HloReduceInstruction* reduce, - std::vector input_generators, - std::vector initial_value_generator, - const llvm_ir::IrArray::Index& index); protected: // @@ -197,6 +187,8 @@ class IrEmitter : public DfsHloVisitorWithDefault, } private: + Status HandleSliceToDynamic(HloInstruction* hlo); + Status HandlePadToStatic(HloInstruction* hlo); Status HandleAllReduceSingleReplica(HloInstruction* crs); Status HandleAllReduceMultipleReplica(HloInstruction* crs); @@ -454,6 +446,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, // module's function list). std::unique_ptr compute_function_; llvm::IRBuilder<> b_; + mlir::MLIRContext* mlir_context_; // The buffer allocation slice for the root of the computation being compiled. // Only relevant for thread local computations. 
diff --git a/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc b/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc new file mode 100644 index 00000000000..e7d52c288d5 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc @@ -0,0 +1,132 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h" + +#include "llvm/Linker/Linker.h" +#include "llvm/Transforms/IPO/Internalize.h" +#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // from @llvm-project +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" // from @llvm-project +#include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Target/LLVMIR.h" // from @llvm-project +#include "tensorflow/compiler/mlir/xla/hlo_utils.h" + +namespace xla { +namespace cpu { +namespace { + +// Lower an MLIR module to an LLVM module. +std::unique_ptr MakeLLVMModule(mlir::OwningModuleRef module) { + mlir::PassManager manager(module->getContext()); + manager.addPass(mlir::createConvertLinalgToLoopsPass()); + manager.addPass(mlir::createConvertLinalgToLLVMPass()); + manager.addPass(mlir::createConvertVectorToLLVMPass()); + manager.addPass(mlir::createLowerToLLVMPass()); + CHECK(succeeded(manager.run(*module))); + return mlir::translateModuleToLLVMIR(*module); +} + +// Get arguments to pass a memref to an mlir function. +void BuildViewForBuffer(llvm::SmallVectorImpl *args, + llvm::IRBuilder<> *b, const Shape &opShape, + llvm::Value *op_val) { + llvm::Type *ty = op_val->getType(); + while (auto aty = llvm::dyn_cast( + llvm::cast(ty)->getElementType())) { + ty = aty->getElementType()->getPointerTo(); + } + op_val = b->CreateBitCast(op_val, ty); + + args->push_back(op_val); // Allocated pointer. + args->push_back(op_val); // Aligned pointer. + args->push_back(b->getInt64(0)); // Offset. + + // Sizes. + for (int64 dim : opShape.dimensions()) { + args->push_back(b->getInt64(dim)); + } + + int64_t accumulated_stride = 1; + llvm::SmallVector strides(opShape.rank(), 1); + for (int64 dim : LayoutUtil::MinorToMajor(opShape)) { + strides[dim] = accumulated_stride; + accumulated_stride *= opShape.dimensions(dim); + } + + // Strides. 
+ for (int64 stride : strides) { + args->push_back(b->getInt64(stride)); + } +} +} // namespace + +Status EmitMlirFuncAndCall( + mlir::MLIRContext *context, llvm::IRBuilder<> *b, const Shape &result_shape, + llvm::ArrayRef operand_shapes, llvm::Value *result_ptr, + llvm::ArrayRef operand_ptrs, llvm::StringRef func_name, + llvm::function_ref emitter) { + llvm::Module *llvm_module = b->GetInsertBlock()->getParent()->getParent(); + mlir::Builder mlir_builder(context); + + // Get memref types for the inputs and output. + TF_ASSIGN_OR_RETURN(mlir::Type ret_memref, ConvertTensorShapeToMemRefType( + result_shape, mlir_builder)); + std::vector operand_types = {ret_memref}; + for (int i = 0; i != operand_shapes.size(); ++i) { + TF_ASSIGN_OR_RETURN( + mlir::Type op_memref, + ConvertTensorShapeToMemRefType(operand_shapes[i], mlir_builder)); + operand_types.push_back(op_memref); + } + + // Create the function an call the emission callback. + mlir::Location loc = mlir::UnknownLoc::get(context); + auto function = mlir::FuncOp::create( + loc, func_name, mlir::FunctionType::get(operand_types, {}, context)); + function.addEntryBlock(); + mlir::OwningModuleRef mlir_module = mlir::ModuleOp::create(loc); + mlir_module->push_back(function); + mlir::OpBuilder op_builder(&function.getBody()); + emitter(&op_builder, function); + + // Now link it all into the main LLVM module. + auto mlir_llvm_module = MakeLLVMModule(std::move(mlir_module)); + mlir_llvm_module->setDataLayout(llvm_module->getDataLayout()); + llvm::Linker::linkModules( + *llvm_module, std::move(mlir_llvm_module), llvm::Linker::None, + [](llvm::Module &M, const llvm::StringSet<> &GVS) { + llvm::internalizeModule(M, [&GVS](const llvm::GlobalValue &GV) { + return !GV.hasName() || (GVS.count(GV.getName()) == 0); + }); + }); + + // And leave behind a call to the function generated by MLIR. + llvm::Function *func = llvm_module->getFunction(func_name); + llvm::SmallVector op_vals; + BuildViewForBuffer(&op_vals, b, result_shape, result_ptr); + for (int i = 0; i != operand_shapes.size(); ++i) { + BuildViewForBuffer(&op_vals, b, operand_shapes[i], operand_ptrs[i]); + } + b->CreateCall(func, op_vals); + + return Status::OK(); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/mlir_emitter.h b/tensorflow/compiler/xla/service/cpu/mlir_emitter.h new file mode 100644 index 00000000000..bc0741e851a --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/mlir_emitter.h @@ -0,0 +1,43 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_MLIR_EMITTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_MLIR_EMITTER_H_ + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Value.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/status.h" + +namespace xla { +namespace cpu { + +// Create a new MLIR function with the name `func_name`, populate it with +// `emitter` and create a call, passing it the buffers defined by +// resultShape/resultPtr and operandShapes/operandPtrs. The function is added to +// the LLVM module at `b`s insertion point. +Status EmitMlirFuncAndCall( + mlir::MLIRContext *context, llvm::IRBuilder<> *b, const Shape &result_shape, + llvm::ArrayRef operand_shapes, llvm::Value *result_ptr, + llvm::ArrayRef operand_ptrs, llvm::StringRef func_name, + llvm::function_ref emitter); + +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_MLIR_EMITTER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index 14afe770ede..225102e6ae6 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -142,24 +142,29 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( // in-place will only touch the updated elements). // TODO(b/27458679) Parallelize instructions which are skipped here. auto opcode = instruction->opcode(); - if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant || - opcode == HloOpcode::kCall || opcode == HloOpcode::kCustomCall || - opcode == HloOpcode::kDot || opcode == HloOpcode::kSelectAndScatter || - opcode == HloOpcode::kGetTupleElement || opcode == HloOpcode::kBitcast || - opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed || - opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng || - opcode == HloOpcode::kSort || - (opcode == HloOpcode::kConvolution && - PotentiallyImplementedAsEigenConvolution(*instruction, - target_machine_features_)) || - (opcode == HloOpcode::kFusion && !instruction->IsLoopFusion()) || - llvm_ir::MayBeImplementedAsInPlaceDynamicUpdateSlice(instruction) || - instruction->shape().IsTuple()) { + if (llvm_ir::MayBeImplementedAsInPlaceDynamicUpdateSlice(instruction) || + instruction->shape().IsTuple() || opcode == HloOpcode::kRng) { return 1; } - // Consult 'cost_model_' to compute target parallel task count. - return cost_model_->GetParallelTaskCount(instruction); + // Only allow known good instructions. + if (instruction->IsElementwise() || instruction->IsLoopFusion() || + opcode == HloOpcode::kBroadcast || opcode == HloOpcode::kConcatenate || + opcode == HloOpcode::kDynamicSlice || + opcode == HloOpcode::kDynamicUpdateSlice || + opcode == HloOpcode::kGather || opcode == HloOpcode::kIota || + opcode == HloOpcode::kPad || opcode == HloOpcode::kReduce || + opcode == HloOpcode::kReduceWindow || opcode == HloOpcode::kReshape || + opcode == HloOpcode::kReverse || opcode == HloOpcode::kSlice || + opcode == HloOpcode::kTranspose || + (opcode == HloOpcode::kConvolution && + !PotentiallyImplementedAsEigenConvolution(*instruction, + target_machine_features_))) { + // Consult 'cost_model_' to compute target parallel task count. 
+ return cost_model_->GetParallelTaskCount(instruction); + } + + return 1; } StatusOr ParallelTaskAssigner::Run(HloModule* module) { diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index e2c93568b74..e22210a61f2 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -170,5 +170,26 @@ TEST_F(ParallelTaskAssignmentTest, InPlaceDynamicUpdateSliceNotParallelized) { EXPECT_FALSE(changed); } +TEST_F(ParallelTaskAssignmentTest, AllReduceNotParallelized) { + constexpr char hlo_string[] = R"( + HloModule TestTaskParallel_allreduce + add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + + ENTRY CRS { + input = f32[1234567] parameter(0) + ROOT crs = f32[1234567] all-reduce(input), replica_groups={}, to_apply=add + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(m.get())); + EXPECT_FALSE(changed); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft.cc b/tensorflow/compiler/xla/service/cpu/runtime_fft.cc index 051120be324..0c1e9dae751 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fft.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_fft.cc @@ -28,13 +28,14 @@ using tensorflow::int64; TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenFft( const void* run_options_ptr, void* out, void* operand, int32 fft_type, - int32 fft_rank, int64 input_batch, int64 fft_length0, int64 fft_length1, - int64 fft_length2) { + int32 double_precision, int32 fft_rank, int64 input_batch, + int64 fft_length0, int64 fft_length1, int64 fft_length2) { const xla::ExecutableRunOptions* run_options = static_cast(run_options_ptr); XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr); tensorflow::xla::EigenFftImpl( *run_options->intra_op_thread_pool(), out, operand, - static_cast(fft_type), fft_rank, input_batch, - fft_length0, fft_length1, fft_length2); + static_cast(fft_type), + static_cast(double_precision), fft_rank, input_batch, fft_length0, + fft_length1, fft_length2); } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft.h b/tensorflow/compiler/xla/service/cpu/runtime_fft.h index f20c5aa0aa2..d95da172116 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fft.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_fft.h @@ -22,7 +22,8 @@ extern "C" { extern void __xla_cpu_runtime_EigenFft( const void* /* xla::ExecutableRunOptions* */ run_options_ptr, void* out, - void* operand, tensorflow::int32 fft_type, tensorflow::int32 fft_rank, + void* operand, tensorflow::int32 fft_type, + tensorflow::int32 double_precision, tensorflow::int32 fft_rank, tensorflow::int64 input_batch, tensorflow::int64 fft_length0, tensorflow::int64 fft_length1, tensorflow::int64 fft_length2); diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h index 04dea120a8d..124e7d589a0 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h @@ -39,8 +39,8 @@ static constexpr int kFftTypeArraySize = 4; namespace internal { // Computes either a forward or reverse complex-to-complex FFT. 
-template -void EigenFftC2C(const EigenDevice& device, complex64* out, complex64* operand, +template +void EigenFftC2C(const EigenDevice& device, Complex* out, Complex* operand, int64 input_batch, int64 fft_length0, int64 fft_length1, int64 fft_length2) { // Create the axes (which are always trailing). @@ -55,10 +55,10 @@ void EigenFftC2C(const EigenDevice& device, complex64* out, complex64* operand, for (int i = 0; i < FFTRank; i++) { dims[i + 1] = fft_shape[i]; } - const Eigen::TensorMap, + const Eigen::TensorMap, Eigen::Aligned> input(operand, dims); - Eigen::TensorMap, + Eigen::TensorMap, Eigen::Aligned> output(out, dims); output.device(device) = input.template fft(axes); @@ -66,8 +66,8 @@ void EigenFftC2C(const EigenDevice& device, complex64* out, complex64* operand, // Computes a forward real->complex FFT, slicing out redundant negative // frequencies from the innermost dimension. -template -void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand, +template +void EigenFftR2C(const EigenDevice& device, Complex* out, Real* operand, int64 input_batch, int64 fft_length0, int64 fft_length1, int64 fft_length2) { const std::array fft_shape = { @@ -81,10 +81,10 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand, in_dims[i + 1] = fft_shape[i]; out_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i]; } - const Eigen::TensorMap, + const Eigen::TensorMap, Eigen::Aligned> input(operand, in_dims); - Eigen::TensorMap, + Eigen::TensorMap, Eigen::Aligned> output(out, out_dims); @@ -92,7 +92,7 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand, const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank); // Compute the full FFT using a temporary tensor. - Eigen::Tensor full_fft(in_dims); + Eigen::Tensor full_fft(in_dims); const Eigen::DSizes zero_start_indices; full_fft.device(device) = @@ -105,8 +105,8 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand, // Computes a reverse complex->real FFT, reconstructing redundant negative // frequencies using reverse conjugate on innermost dimension after doing IFFT // on outer dimensions. -template -void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand, +template +void EigenFftC2R(const EigenDevice& device, Real* out, Complex* operand, int64 input_batch, int64 fft_length0, int64 fft_length1, int64 fft_length2) { const std::array fft_shape = { @@ -120,10 +120,10 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand, in_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i]; out_dims[i + 1] = fft_shape[i]; } - const Eigen::TensorMap, + const Eigen::TensorMap, Eigen::Aligned> input(operand, in_dims); - Eigen::TensorMap, + Eigen::TensorMap, Eigen::Aligned> output(out, out_dims); @@ -131,7 +131,7 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand, // region we will slice from input given fft_shape. We slice input to // fft_shape on its inner-most dimensions, except the last (which we // slice to fft_shape[-1] / 2 + 1). - Eigen::Tensor full_fft(out_dims); + Eigen::Tensor full_fft(out_dims); // Calculate the starting point and range of the source of // negative frequency part. 
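The EigenFftWithRank change in the next hunk turns the runtime double_precision flag into a compile-time element type, dispatching to either a single- or a double-precision instantiation of the templated helpers above. Below is a minimal, self-contained sketch of that dispatch pattern, with a toy O(n^2) DFT standing in for Eigen's tensor FFT; it is illustrative only and not the XLA runtime code:

#include <algorithm>
#include <cmath>
#include <complex>
#include <cstddef>
#include <vector>

// Naive DFT, templated on the complex type just like the Eigen helpers above.
template <typename Complex>
std::vector<Complex> NaiveDft(const std::vector<Complex>& in) {
  using Real = typename Complex::value_type;
  const std::size_t n = in.size();
  const Real two_pi = static_cast<Real>(6.283185307179586);
  std::vector<Complex> out(n);
  for (std::size_t k = 0; k < n; ++k) {
    Complex acc(0, 0);
    for (std::size_t t = 0; t < n; ++t) {
      const Real angle =
          -two_pi * static_cast<Real>(k * t) / static_cast<Real>(n);
      acc += in[t] * Complex(std::cos(angle), std::sin(angle));
    }
    out[k] = acc;
  }
  return out;
}

// Runtime precision flag -> compile-time element type, mirroring how
// EigenFftWithRank selects between the complex64 and complex128 paths.
void RunC2CFft(const void* operand, void* out, std::size_t n,
               bool double_precision) {
  if (double_precision) {
    const auto* in = static_cast<const std::complex<double>*>(operand);
    std::vector<std::complex<double>> result(
        NaiveDft(std::vector<std::complex<double>>(in, in + n)));
    std::copy(result.begin(), result.end(),
              static_cast<std::complex<double>*>(out));
  } else {
    const auto* in = static_cast<const std::complex<float>*>(operand);
    std::vector<std::complex<float>> result(
        NaiveDft(std::vector<std::complex<float>>(in, in + n)));
    std::copy(result.begin(), result.end(),
              static_cast<std::complex<float>*>(out));
  }
}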
@@ -178,30 +178,59 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand, template void EigenFftWithRank(const EigenDevice& device, void* out, void* operand, - FftType fft_type, int64 input_batch, int64 fft_length0, - int64 fft_length1, int64 fft_length2) { + FftType fft_type, bool double_precision, + int64 input_batch, int64 fft_length0, int64 fft_length1, + int64 fft_length2) { switch (fft_type) { case FftType::FFT: - EigenFftC2C( - device, static_cast(out), - static_cast(operand), input_batch, fft_length0, - fft_length1, fft_length2); + if (double_precision) { + EigenFftC2C( + device, static_cast(out), + static_cast(operand), input_batch, fft_length0, + fft_length1, fft_length2); + } else { + EigenFftC2C( + device, static_cast(out), + static_cast(operand), input_batch, fft_length0, + fft_length1, fft_length2); + } break; case FftType::IFFT: - EigenFftC2C( - device, static_cast(out), - static_cast(operand), input_batch, fft_length0, - fft_length1, fft_length2); + if (double_precision) { + EigenFftC2C( + device, static_cast(out), + static_cast(operand), input_batch, fft_length0, + fft_length1, fft_length2); + } else { + EigenFftC2C( + device, static_cast(out), + static_cast(operand), input_batch, fft_length0, + fft_length1, fft_length2); + } break; case FftType::RFFT: - EigenFftR2C( - device, static_cast(out), static_cast(operand), - input_batch, fft_length0, fft_length1, fft_length2); + if (double_precision) { + EigenFftR2C( + device, static_cast(out), + static_cast(operand), input_batch, fft_length0, + fft_length1, fft_length2); + } else { + EigenFftR2C( + device, static_cast(out), static_cast(operand), + input_batch, fft_length0, fft_length1, fft_length2); + } break; case FftType::IRFFT: - EigenFftC2R( - device, static_cast(out), static_cast(operand), - input_batch, fft_length0, fft_length1, fft_length2); + if (double_precision) { + EigenFftC2R( + device, static_cast(out), + static_cast(operand), input_batch, fft_length0, + fft_length1, fft_length2); + } else { + EigenFftC2R( + device, static_cast(out), static_cast(operand), + input_batch, fft_length0, fft_length1, fft_length2); + } break; default: // Unsupported FFT type @@ -213,22 +242,24 @@ void EigenFftWithRank(const EigenDevice& device, void* out, void* operand, template void EigenFftImpl(const EigenDevice& device, void* out, void* operand, - FftType fft_type, int32 fft_rank, int64 input_batch, - int64 fft_length0, int64 fft_length1, int64 fft_length2) { + FftType fft_type, bool double_precision, int32 fft_rank, + int64 input_batch, int64 fft_length0, int64 fft_length1, + int64 fft_length2) { switch (fft_rank) { case 1: - internal::EigenFftWithRank<1, EigenDevice>( - device, out, operand, fft_type, input_batch, fft_length0, 0, 0); + internal::EigenFftWithRank<1, EigenDevice>(device, out, operand, fft_type, + double_precision, input_batch, + fft_length0, 0, 0); break; case 2: internal::EigenFftWithRank<2, EigenDevice>(device, out, operand, fft_type, - input_batch, fft_length0, - fft_length1, 0); + double_precision, input_batch, + fft_length0, fft_length1, 0); break; case 3: - internal::EigenFftWithRank<3, EigenDevice>(device, out, operand, fft_type, - input_batch, fft_length0, - fft_length1, fft_length2); + internal::EigenFftWithRank<3, EigenDevice>( + device, out, operand, fft_type, double_precision, input_batch, + fft_length0, fft_length1, fft_length2); break; default: // Unsupported FFT rank diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc 
b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc index 7831c1b1b5b..0d4e7055ddb 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc @@ -60,6 +60,11 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSort( std::unique_ptr reordered_values( new std::string[sort_dimension_elements]); for (int64 index = 0; index < num_iteration_elements; ++index) { + // If the sort should be stable, we have to reinitialize indices to iota to + // guarantee that we still keep the relative order in case of ties. + if (is_stable && index > 0) { + std::iota(indices.get(), indices.get() + sort_dimension_elements, 0); + } // 'index' can be split into two values which index into the 'c' dimension // and the 'a' dimension, respectively. 'index' % 'c' is the index into the // 'c' dimension, 'index' / 'c' is the index into the 'a' dimension. When diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc index 7d6c4942b69..35db15fed2c 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc @@ -114,6 +114,22 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF64( transpose_rhs); } +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulC64( + const void* run_options_ptr, std::complex* out, + std::complex* lhs, std::complex* rhs, int64 m, int64 n, + int64 k, int32 transpose_lhs, int32 transpose_rhs) { + MatMulDispatch>(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulC128( + const void* run_options_ptr, std::complex* out, + std::complex* lhs, std::complex* rhs, int64 m, int64 n, + int64 k, int32 transpose_lhs, int32 transpose_rhs) { + MatMulDispatch>(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); +} + TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulS32( const void* run_options_ptr, int32* out, int32* lhs, int32* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.h b/tensorflow/compiler/xla/service/cpu/runtime_matmul.h index 1280d04d01f..11dfc5c1d80 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.h @@ -16,6 +16,8 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_H_ +#include + #include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/platform/types.h" @@ -44,6 +46,18 @@ extern void __xla_cpu_runtime_EigenMatMulF64( tensorflow::int64 k, tensorflow::int32 transpose_lhs, tensorflow::int32 transpose_rhs); +extern void __xla_cpu_runtime_EigenMatMulC64( + const void* run_options_ptr, std::complex* out, + std::complex* lhs, std::complex* rhs, tensorflow::int64 m, + tensorflow::int64 n, tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs); + +extern void __xla_cpu_runtime_EigenMatMulC128( + const void* run_options_ptr, std::complex* out, + std::complex* lhs, std::complex* rhs, tensorflow::int64 m, + tensorflow::int64 n, tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs); + extern void __xla_cpu_runtime_EigenMatMulS32( const void* /* xla::ExecutableRunOptions* */ run_options_ptr, tensorflow::int32* out, tensorflow::int32* lhs, tensorflow::int32* rhs, diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc index d2780dd694e..9476dce5ced 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc @@ -24,10 +24,11 @@ using tensorflow::int64; TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenSingleThreadedFft( const void* run_options_ptr, void* out, void* operand, int32 fft_type, - int32 fft_rank, int64 input_batch, int64 fft_length0, int64 fft_length1, - int64 fft_length2) { + int32 double_precision, int32 fft_rank, int64 input_batch, + int64 fft_length0, int64 fft_length1, int64 fft_length2) { tensorflow::xla::EigenFftImpl(Eigen::DefaultDevice(), out, operand, static_cast(fft_type), - fft_rank, input_batch, fft_length0, fft_length1, + static_cast(double_precision), fft_rank, + input_batch, fft_length0, fft_length1, fft_length2); } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h index dcd133d012c..2f0ccda2d10 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h @@ -22,7 +22,8 @@ extern "C" { extern void __xla_cpu_runtime_EigenSingleThreadedFft( const void* /* xla::ExecutableRunOptions* */ run_options_ptr, void* out, - void* operand, tensorflow::int32 fft_type, tensorflow::int32 fft_rank, + void* operand, tensorflow::int32 fft_type, + tensorflow::int32 double_precision, tensorflow::int32 fft_rank, tensorflow::int64 input_batch, tensorflow::int64 fft_length0, tensorflow::int64 fft_length1, tensorflow::int64 fft_length2); diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc index e395bc7426c..c7601f939c7 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc @@ -112,6 +112,24 @@ __xla_cpu_runtime_EigenSingleThreadedMatMulF64(const void* run_options_ptr, transpose_lhs, transpose_rhs); } +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_EigenSingleThreadedMatMulC64( + const void* run_options_ptr, std::complex* out, + std::complex* lhs, std::complex* rhs, int64 m, int64 
n, + int64 k, int32 transpose_lhs, int32 transpose_rhs) { + SingleThreadedMatMulDispatch>( + run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_EigenSingleThreadedMatMulC128( + const void* run_options_ptr, std::complex* out, + std::complex* lhs, std::complex* rhs, int64 m, int64 n, + int64 k, int32 transpose_lhs, int32 transpose_rhs) { + SingleThreadedMatMulDispatch>( + run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); +} + TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenSingleThreadedMatMulS32(const void* run_options_ptr, int32* out, int32* lhs, diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h index eb695910729..61fe224d420 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_MATMUL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_MATMUL_H_ +#include + #include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/platform/types.h" @@ -44,6 +46,20 @@ extern void __xla_cpu_runtime_EigenSingleThreadedMatMulF64( tensorflow::int64 k, tensorflow::int32 transpose_lhs, tensorflow::int32 transpose_rhs); +extern void __xla_cpu_runtime_EigenSingleThreadedMatMulC64( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, + std::complex* out, std::complex* lhs, + std::complex* rhs, tensorflow::int64 m, tensorflow::int64 n, + tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs); + +extern void __xla_cpu_runtime_EigenSingleThreadedMatMulC128( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, + std::complex* out, std::complex* lhs, + std::complex* rhs, tensorflow::int64 m, tensorflow::int64 n, + tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs); + extern void __xla_cpu_runtime_EigenSingleThreadedMatMulS32( const void* /* xla::ExecutableRunOptions* */ run_options_ptr, tensorflow::int32* out, tensorflow::int32* lhs, tensorflow::int32* rhs, diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index 395eb31c13f..4cc9e373b3e 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -246,6 +246,8 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64); + REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulC64); + REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulC128); REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulS32); REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF32); REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF64); @@ -257,6 +259,8 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64); + REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulC64); + REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulC128); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulS32); REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue); diff --git 
a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index f52de3394fe..1ac8509cdb1 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -35,6 +35,19 @@ cc_library( ], ) +tf_cc_test( + name = "cpu_dyn_shape_test", + srcs = ["cpu_dyn_shape_test.cc"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "cpu_fusion_test", srcs = ["cpu_fusion_test.cc"], diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_dyn_shape_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_dyn_shape_test.cc new file mode 100644 index 00000000000..46249caa0c7 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_dyn_shape_test.cc @@ -0,0 +1,60 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" + +namespace xla { +namespace cpu { +namespace { + +using CpuDynamicShapeTest = CpuCodegenTest; + +TEST_F(CpuDynamicShapeTest, DynamicShapeR2) { + HloComputation::Builder builder(TestName()); + + xla::Shape dyn_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 4}); + dyn_input_shape.set_dynamic_dimension(0, true); + HloInstruction* param_x = builder.AddInstruction( + HloInstruction::CreateParameter(0, dyn_input_shape, "x")); + + builder.AddInstruction(HloInstruction::CreateUnary( + dyn_input_shape, HloOpcode::kNegate, param_x)); + auto hlo_module = CreateNewVerifiedModule(); + hlo_module->AddEntryComputation(builder.Build()); + + string filecheck_pattern = R"( +; CHECK: %[[dyn_dim_size:.*]] = load i32, i32* +; CHECK: %[[i64_dyn_dim_size:.*]] = sext i32 %[[dyn_dim_size:.*]] to i64 +; CHECK: icmp uge i64 %[[custom:.*]], %[[i64_dyn_dim_size:.*]] +; CHECK: %[[multiplier:.*]] = mul i64 1, %[[i64_dyn_dim_size:.*]] +; CHECK: mul nuw nsw i64 %[[custom:.*]], %[[multiplier:.*]] +)"; + + CpuAotCompilationOptions options{ + /*triple=*/"x86_64", /*cpu_name=*/"", /*features=*/"", + /*entry_point_name=*/"entry", + /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; + + CompileAheadOfTimeAndVerifyIr(std::move(hlo_module), options, + filecheck_pattern, + /*match_optimized_ir=*/false); +} + +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc index b6d6de28bc5..efeab3bd31a 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc @@ -70,6 +70,13 @@ class 
CpuUnaryIntrinsicTest return absl::StrCat(opcode, "_On_", triple, (features.empty() ? "" : "_With"), features); } + + private: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + HloTestBase::SetAotFastMathDebugOptions(&debug_options); + return debug_options; + } }; // Creates a module with a call to the unary op, and tests if the diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_vectorization_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_vectorization_test.cc index 8a72eb15487..757d878e224 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_vectorization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_vectorization_test.cc @@ -69,6 +69,13 @@ class CpuVectorizationTest return absl::StrCat(opcode, "_On_", triple, (features.empty() ? "" : "_With"), features); } + + private: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + HloTestBase::SetAotFastMathDebugOptions(&debug_options); + return debug_options; + } }; TEST_P(CpuVectorizationTest, DoIt) { diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index cadea620ec6..bdaac32a0e5 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -116,9 +116,12 @@ class DfsHloVisitorBase { virtual Status HandleFft(HloInstructionPtr fft) = 0; virtual Status HandleTriangularSolve(HloInstructionPtr hlo) = 0; virtual Status HandleCholesky(HloInstructionPtr hlo) = 0; + virtual Status HandleAllGather(HloInstructionPtr hlo) = 0; virtual Status HandleAllReduce(HloInstructionPtr hlo) = 0; virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0; + virtual Status HandleCollectivePermuteStart(HloInstructionPtr hlo) = 0; + virtual Status HandleCollectivePermuteDone(HloInstructionPtr hlo) = 0; virtual Status HandleReplicaId(HloInstructionPtr hlo) = 0; virtual Status HandlePartitionId(HloInstructionPtr hlo) = 0; virtual Status HandleGetDimensionSize(HloInstructionPtr hlo) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index baa9240fb56..b1d674fe467 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -98,6 +98,9 @@ class DfsHloVisitorWithDefaultBase Status HandleCholesky(HloInstructionPtr hlo) override { return DefaultAction(hlo); } + Status HandleAllGather(HloInstructionPtr crs) override { + return DefaultAction(crs); + } Status HandleAllReduce(HloInstructionPtr crs) override { return DefaultAction(crs); } @@ -107,6 +110,12 @@ class DfsHloVisitorWithDefaultBase Status HandleCollectivePermute(HloInstructionPtr hlo) override { return DefaultAction(hlo); } + Status HandleCollectivePermuteStart(HloInstructionPtr hlo) override { + return DefaultAction(hlo); + } + Status HandleCollectivePermuteDone(HloInstructionPtr hlo) override { + return DefaultAction(hlo); + } Status HandleReplicaId(HloInstructionPtr hlo) override { return DefaultAction(hlo); } diff --git a/tensorflow/compiler/xla/service/dot_as_convolution_util.cc b/tensorflow/compiler/xla/service/dot_as_convolution_util.cc new file mode 100644 index 00000000000..fcdf85d5ecb --- /dev/null +++ b/tensorflow/compiler/xla/service/dot_as_convolution_util.cc @@ 
-0,0 +1,139 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h" + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace xla { +namespace dot_as_convolution_util { + +/* static */ absl::optional +ParseDotGeneralFromConvolution(const HloInstruction* conv) { + CHECK_EQ(conv->opcode(), HloOpcode::kConvolution); + if (conv->feature_group_count() != 1 || conv->batch_group_count() != 1) { + return absl::nullopt; + } + const auto& conv_dims = conv->convolution_dimension_numbers(); + DotGeneralAsConvolutionDimsInfo dims; + dims.lhs_non_contracting_dims.push_back( + {conv_dims.input_batch_dimension(), -1, + conv_dims.output_batch_dimension(), -1}); + dims.rhs_non_contracting_dims.push_back( + {-1, conv_dims.kernel_output_feature_dimension(), + conv_dims.output_feature_dimension(), -1}); + dims.contracting_dims.push_back({conv_dims.input_feature_dimension(), + conv_dims.kernel_input_feature_dimension(), + -1, -1}); + + for (int64 i = 0; i < conv_dims.input_spatial_dimensions_size(); ++i) { + int64 lhs = conv_dims.input_spatial_dimensions(i); + int64 lhs_size = conv->operand(0)->shape().dimensions(lhs); + int64 rhs = conv_dims.kernel_spatial_dimensions(i); + int64 rhs_size = conv->operand(1)->shape().dimensions(rhs); + int64 output = conv_dims.output_spatial_dimensions(i); + const auto& wd = conv->window().dimensions(i); + if (lhs_size == wd.size() && + std::max(1, lhs_size - 1) == wd.stride() && + lhs_size == wd.base_dilation() && wd.window_dilation() == 1 && + wd.padding_high() == 0 && wd.padding_low() == 0 && + !wd.window_reversal()) { + // A batch dimension in DotGeneral is represented as a spatial dimension + // with window size B (batch dimension size), stride B - 1, and base + // dilation B. + dims.batch_dims.push_back({lhs, rhs, output, i}); + } else if (lhs_size == wd.size() && wd.base_dilation() == 1 && + wd.window_dilation() == 1 && wd.padding_high() == 0 && + wd.padding_low() == 0 && !wd.window_reversal()) { + // A contracting dimension be represented as a spatial dimension with + // window size C (contracting dimension size). Stride can be any size + // since there is only one window. + dims.contracting_dims.push_back({lhs, rhs, output, i}); + } else if (wd.stride() == 1 && wd.window_dilation() == 1 && + wd.base_dilation() == 1) { + if (rhs_size == 1 && wd.size() == 1 && wd.padding_high() == 0 && + wd.padding_low() == 0 && !wd.window_reversal()) { + // A LHS non-contracting dimension can be represented as a spatial + // dimension with window size 1. 
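+ // For example, the M dimension of a [B, M, K] x [B, K, N] dot maps to an
+ // input spatial dimension of size M convolved with a size-1 kernel
+ // dimension: with window size 1, stride 1 and no padding, each output
+ // element along M reads exactly one input element, so the dimension passes
+ // through unchanged.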
+ dims.lhs_non_contracting_dims.push_back({lhs, rhs, output, i}); + } else if (lhs_size == 1 && wd.size() == rhs_size && + wd.padding_high() == rhs_size - 1 && + wd.padding_low() == rhs_size - 1 && wd.window_reversal()) { + // A RHS non-contracting dimension can be represented as a spatial + // dimension with window size N (non-contracting dimension size), low + // padding N - 1, high padding N - 1 and window reversal. + dims.rhs_non_contracting_dims.push_back({lhs, rhs, output, i}); + } else { + return absl::nullopt; + } + } else { + return absl::nullopt; + } + } + + return dims; +} + +StatusOr> +CreateShardedConvForDotGeneralConvolution( + const HloInstruction& conv, + const DotGeneralAsConvolutionDimsInfo& dot_dnums, + HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo) { + CHECK_EQ(conv.opcode(), HloOpcode::kConvolution); + const auto& conv_dnums = conv.convolution_dimension_numbers(); + auto window = conv.window(); + for (const auto& dim : dot_dnums.batch_dims) { + auto wd = window.mutable_dimensions(dim.spatial_dim); + wd->set_size(sharded_lhs_hlo->shape().dimensions( + conv_dnums.input_spatial_dimensions(dim.spatial_dim))); + wd->set_stride(std::max(1, wd->size() - 1)); + wd->set_base_dilation(wd->size()); + } + for (const auto& dim : dot_dnums.contracting_dims) { + if (dim.spatial_dim < 0) { + continue; + } + auto wd = window.mutable_dimensions(dim.spatial_dim); + wd->set_size(sharded_lhs_hlo->shape().dimensions( + conv_dnums.input_spatial_dimensions(dim.spatial_dim))); + } + for (const auto& dim : dot_dnums.rhs_non_contracting_dims) { + if (dim.spatial_dim < 0) { + continue; + } + auto wd = window.mutable_dimensions(dim.spatial_dim); + wd->set_size(sharded_rhs_hlo->shape().dimensions( + conv_dnums.kernel_spatial_dimensions(dim.spatial_dim))); + wd->set_padding_high(wd->size() - 1); + wd->set_padding_low(wd->size() - 1); + } + TF_ASSIGN_OR_RETURN(Shape sharded_conv_shape, + ShapeInference::InferConvolveShape( + sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(), + /*feature_group_count=*/1, + /*batch_group_count=*/1, window, conv_dnums)); + *sharded_conv_shape.mutable_layout() = conv.shape().layout(); + return HloInstruction::CreateConvolve( + sharded_conv_shape, sharded_lhs_hlo, sharded_rhs_hlo, + /*feature_group_count=*/1, + /*batch_group_count=*/1, window, conv_dnums, conv.precision_config()); +} + +} // namespace dot_as_convolution_util +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dot_as_convolution_util.h b/tensorflow/compiler/xla/service/dot_as_convolution_util.h new file mode 100644 index 00000000000..a3e829a3d31 --- /dev/null +++ b/tensorflow/compiler/xla/service/dot_as_convolution_util.h @@ -0,0 +1,68 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DOT_AS_CONVOLUTION_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_DOT_AS_CONVOLUTION_UTIL_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" + +namespace xla { +namespace dot_as_convolution_util { + +// Describes the dimensions of a convolution that can be interpreted as a dot. +struct DotGeneralAsConvolutionDimsInfo { + // The dimension numbers for the operands and output corresponding to a + // logical dimension (e.g., batch, contracting, non-contracting). If an + // operand or the output doesn't have the logical dimension, it is set to + // -1. + struct DimNums { + int64 lhs; + int64 rhs; + int64 output; + // The corresponding spatial dimension in the convolution's config. Set to + // -1 if it's not mapped to a spatial dimension. + int64 spatial_dim; + }; + std::vector batch_dims; + std::vector contracting_dims; + std::vector lhs_non_contracting_dims; + std::vector rhs_non_contracting_dims; +}; + +// Parses a convolution and returns a DotGeneralAsConvolutionDimsInfo if it can +// be interpreted as a dot, or absl::nullopt otherwise. +absl::optional ParseDotGeneralFromConvolution( + const HloInstruction* conv); + +// Creates sharded convolution instruction that can be interpreted as a dot. +// This is a utility for per-op partitioners. +// - 'conv' is the original convolution instruction. +// - 'dot_dnums' is the result of ParseDotGeneralFromConvolution() for 'conv'. +// - 'sharded_lhs_hlo' and 'sharded_rhs_hlo' are sharded inputs for the result +// convolution instruction. +StatusOr> +CreateShardedConvForDotGeneralConvolution( + const HloInstruction& conv, + const DotGeneralAsConvolutionDimsInfo& dot_dnums, + HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo); + +} // namespace dot_as_convolution_util +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DOT_AS_CONVOLUTION_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc index 353a7f5cebc..40354dec3c6 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer.cc +++ b/tensorflow/compiler/xla/service/dot_decomposer.cc @@ -31,7 +31,7 @@ namespace { // Convert a dot into a canonical form where non-contracting and contracting // dimensions are reshaped together and batch dimensions are the most major -// dimensions. The requires transposing and reshapes the lhs and rhs and +// dimensions. This requires transposing and reshapes of the lhs and rhs and // reshaping the output batch to the original shape. Status CanonicalizeDot(HloInstruction* original_dot) { auto computation = original_dot->parent(); @@ -80,7 +80,9 @@ Status CanonicalizeDot(HloInstruction* original_dot) { lhs_shape), original_dot->mutable_operand(0), lhs_transpose)); std::vector lhs_reshape_dims = batch_dim_sizes; - lhs_reshape_dims.push_back(lhs_non_contracting_size); + if (lhs_non_contracting_size > 1) { + lhs_reshape_dims.push_back(lhs_non_contracting_size); + } lhs_reshape_dims.push_back(lhs_contracting_size); // Reshape the contracting and non-contracting dimensions together. 
HloInstruction* reshaped_lhs = @@ -126,7 +128,9 @@ Status CanonicalizeDot(HloInstruction* original_dot) { std::vector rhs_reshape_dims = batch_dim_sizes; rhs_reshape_dims.push_back(rhs_contracting_size); - rhs_reshape_dims.push_back(rhs_non_contracting_size); + if (rhs_non_contracting_size > 1) { + rhs_reshape_dims.push_back(rhs_non_contracting_size); + } // Reshape the contracting and non-contracting dimensions together. HloInstruction* reshaped_rhs = computation->AddInstruction(HloInstruction::CreateReshape( @@ -134,15 +138,20 @@ Status CanonicalizeDot(HloInstruction* original_dot) { transposed_rhs)); std::vector dot_dims = batch_dim_sizes; - dot_dims.push_back(lhs_non_contracting_size); - dot_dims.push_back(rhs_non_contracting_size); + if (lhs_non_contracting_size > 1) { + dot_dims.push_back(lhs_non_contracting_size); + } + if (rhs_non_contracting_size > 1) { + dot_dims.push_back(rhs_non_contracting_size); + } DotDimensionNumbers dot_dnums; for (int64 i = 0; i < num_batch_dims; ++i) { dot_dnums.add_lhs_batch_dimensions(i); dot_dnums.add_rhs_batch_dimensions(i); } - dot_dnums.add_lhs_contracting_dimensions(num_batch_dims + 1); + dot_dnums.add_lhs_contracting_dimensions( + num_batch_dims + (lhs_non_contracting_size > 1 ? 1 : 0)); dot_dnums.add_rhs_contracting_dimensions(num_batch_dims); HloInstruction* dot = computation->AddInstruction(HloInstruction::CreateDot( @@ -174,9 +183,9 @@ StatusOr DotDecomposer::Run(HloModule* module) { } // A dot is not canonical if it has more than one non-contracting // dimension. - if (dnums.lhs_batch_dimensions_size() + 2 != + if (dnums.lhs_batch_dimensions_size() + 2 < instruction->operand(0)->shape().rank() || - dnums.rhs_batch_dimensions_size() + 2 != + dnums.rhs_batch_dimensions_size() + 2 < instruction->operand(1)->shape().rank()) { non_canonical_dots.push_back(instruction); continue; diff --git a/tensorflow/compiler/xla/service/dot_decomposer_test.cc b/tensorflow/compiler/xla/service/dot_decomposer_test.cc index 67fff50eaf6..c4152393933 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer_test.cc +++ b/tensorflow/compiler/xla/service/dot_decomposer_test.cc @@ -50,5 +50,75 @@ TEST_F(DotDecomposerTest, CanonicalizeMultipleNonContractingDims) { op::Shape("f32[4032,512]")))); } +TEST_F(DotDecomposerTest, DontCanonicalizeIfNoNoncontractingDims) { + absl::string_view module_string = R"( + HloModule module + + ENTRY main { + p0 = f32[64,4]{1,0} parameter(0) + p1 = f32[64,4]{1,0} parameter(1) + ROOT dot = f32[64]{0} dot(p0, p1), lhs_batch_dims={0}, + lhs_contracting_dims={1}, + rhs_batch_dims={0}, + rhs_contracting_dims={1} + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_string)); + TF_ASSERT_OK_AND_ASSIGN(bool canonicalized, + DotDecomposer().Run(module.get())); + EXPECT_FALSE(canonicalized); +} + +TEST_F(DotDecomposerTest, DontAddLhsNonContractingDimIfOne) { + absl::string_view module_string = R"( + HloModule module + + ENTRY main { + p0 = f32[64,4]{1,0} parameter(0) + p1 = f32[64,4,2,1]{3,2,1,0} parameter(1) + ROOT dot = f32[64,2,1]{2,1,0} dot(p0, p1), lhs_batch_dims={0}, + lhs_contracting_dims={1}, + rhs_batch_dims={0}, + rhs_contracting_dims={1} + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_string)); + TF_ASSERT_OK_AND_ASSIGN(bool canonicalized, + DotDecomposer().Run(module.get())); + EXPECT_TRUE(canonicalized); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Reshape(AllOf(op::Dot(op::Reshape(), op::Reshape(), + 
/*lhs_contracting_dim=*/1, + /*rhs_contracting_dim=*/1), + op::Shape("f32[64,2]")))); +} + +TEST_F(DotDecomposerTest, DontAddRhsNonContractingDimIfOne) { + absl::string_view module_string = R"( + HloModule module + + ENTRY main { + p0 = f32[64,4,2,1]{3,2,1,0} parameter(0) + p1 = f32[64,4]{1,0} parameter(1) + ROOT dot = f32[64,2,1]{2,1,0} dot(p0, p1), lhs_batch_dims={0}, + lhs_contracting_dims={1}, + rhs_batch_dims={0}, + rhs_contracting_dims={1} + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_string)); + TF_ASSERT_OK_AND_ASSIGN(bool canonicalized, + DotDecomposer().Run(module.get())); + EXPECT_TRUE(canonicalized); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Reshape(AllOf(op::Dot(op::Reshape(), op::Reshape(), + /*lhs_contracting_dim=*/2, + /*rhs_contracting_dim=*/1), + op::Shape("f32[64,2]")))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 30300b8c195..8cb660de46c 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -2422,6 +2422,43 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( -> StatusOr { return EmitElementalDot(hlo, operand_to_generator, dot_result_index); }; + case HloOpcode::kMap: + return [this, hlo, &operand_to_generator]( + const IrArray::Index& index) -> StatusOr { + std::vector operands; + for (int i = 0; i < hlo->operand_count(); i++) { + TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, + operand_to_generator.at(hlo->operand(i))(index)); + operands.push_back(operand_value); + } + std::vector input_generators; + for (const HloInstruction* instr : hlo->operands()) { + input_generators.push_back(operand_to_generator.at(instr)); + } + return EmitElementalMap(Cast(hlo), operands); + }; + case HloOpcode::kReduceWindow: + return [this, hlo, &operand_to_generator](const IrArray::Index& index) { + return EmitElementalReduceWindow( + Cast(hlo), + operand_to_generator.at(hlo->operand(0)), + operand_to_generator.at(hlo->operand(1)), index); + }; + case HloOpcode::kReduce: + return [this, hlo, &operand_to_generator](const IrArray::Index& index) { + auto reduce_instr = Cast(hlo); + std::vector input_generators; + for (const HloInstruction* instr : reduce_instr->inputs()) { + input_generators.push_back(operand_to_generator.at(instr)); + } + + std::vector initial_value_generators; + for (const HloInstruction* instr : reduce_instr->init_values()) { + initial_value_generators.push_back(operand_to_generator.at(instr)); + } + return EmitElementalReduce(reduce_instr, std::move(input_generators), + std::move(initial_value_generators), index); + }; default: return [hlo](const IrArray::Index& index) { return Unimplemented("Unhandled opcode for elemental IR emission: %s", @@ -2451,4 +2488,215 @@ llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op, return complex; } +StatusOr ElementalIrEmitter::EmitElementalMap( + const HloMapInstruction* map_instr, + absl::Span elemental_operands) { + TF_ASSIGN_OR_RETURN( + std::vector values, + EmitThreadLocalCall(*map_instr->to_apply(), elemental_operands, + llvm_ir::IrName(map_instr))); + CHECK_EQ(values.size(), 1); + return values[0]; +} + +StatusOr ElementalIrEmitter::EmitElementalReduceWindow( + const HloReduceWindowInstruction* reduce_window, + const llvm_ir::ElementGenerator& input_generator, + const llvm_ir::ElementGenerator& 
initial_value_generator, + const llvm_ir::IrArray::Index& index) { + // Pseudocode: + // for each index I in output + // value = init_value + // for each index W in window + // for each dimension i from 0 to rank - 1 + // (input index I)[i] = O[i] * stride[i] + W[i] - pad_low[i] + // if I in bounds of input + // value = function(value, input[I]) + // output[O] = value + const HloInstruction* operand = reduce_window->operand(0); + const Window& window = reduce_window->window(); + + PrimitiveType operand_element_type = operand->shape().element_type(); + llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), + "reduce_window_accum_ptr", b_); + { + TF_ASSIGN_OR_RETURN( + llvm::Value* const init_value, + initial_value_generator(llvm_ir::IrArray::Index(index.GetType()))); + Store(init_value, accum_ptr); + } + + llvm::Type* index_type = index.GetType(); + auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + return index.GetConstantWithIndexType(c); + }; + + llvm_ir::ForLoopNest loops(IrName(reduce_window), b_, index_type); + std::vector window_size; + for (const auto& dim : window.dimensions()) { + window_size.push_back(dim.size()); + } + const IrArray::Index window_index = loops.AddLoopsForShape( + ShapeUtil::MakeShape(operand_element_type, window_size), "window"); + CHECK_EQ(window_index.size(), index.size()); + + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_); + + std::vector input_multi_index(index.size()); + llvm::Value* in_bounds = b_->getInt1(true); + for (size_t i = 0; i < index.size(); ++i) { + llvm::Value* stridden_index = + NSWMul(index[i], index_typed_const(window.dimensions(i).stride())); + input_multi_index[i] = NSWSub( + NSWAdd( + stridden_index, + NSWMul(window_index[i], + index_typed_const(window.dimensions(i).window_dilation()))), + index_typed_const(window.dimensions(i).padding_low())); + + // We need to verify that we are not in the dilated base area. + llvm::Value* dilation_condition = + ICmpEQ(SRem(input_multi_index[i], + index_typed_const(window.dimensions(i).base_dilation())), + index_typed_const(0)); + in_bounds = And(in_bounds, dilation_condition); + + // Apply base dilation to the index. + input_multi_index[i] = + SDiv(input_multi_index[i], + index_typed_const(window.dimensions(i).base_dilation())); + + // We must check whether 0 <= input_multi_index[i] < bound, as + // otherwise we are in the pad and so can skip the computation. This + // comparison is equivalent to the unsigned comparison + // input_multi_index[i] < bound, as a negative value wraps to a large + // positive value. + in_bounds = And(in_bounds, + ICmpULT(input_multi_index[i], + index_typed_const(operand->shape().dimensions(i)))); + } + + llvm_ir::LlvmIfData if_data = + llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_); + SetToFirstInsertPoint(if_data.true_block, b_); + + // We are not in pad, so do the computation. 
+ IrArray::Index input_index(input_multi_index, operand->shape(), index_type); + TF_ASSIGN_OR_RETURN(llvm::Value * input_value, input_generator(input_index)); + TF_ASSIGN_OR_RETURN( + std::vector accum_values, + EmitThreadLocalCall(*reduce_window->to_apply(), + {Load(accum_ptr), input_value}, "reducer_function")); + CHECK_EQ(accum_values.size(), 1); + Store(accum_values[0], accum_ptr); + + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_); + return Load(accum_ptr); +} + +StatusOr ElementalIrEmitter::EmitElementalReduce( + const HloReduceInstruction* reduce, + std::vector input_generators, + std::vector initial_value_generators, + const llvm_ir::IrArray::Index& index) { + const Shape& out_shape = reduce->shape(); + bool is_variadic = !out_shape.IsArray(); + int accumulators_count = 1; + if (is_variadic) { + CHECK(out_shape.IsTuple()); + accumulators_count = out_shape.tuple_shapes_size(); + } + + absl::Span reduced_dimensions(reduce->dimensions()); + + std::vector accumulator_addrs; + std::vector accumulator_types; + llvm::Type* index_type = index.GetType(); + for (int i = 0; i < accumulators_count; i++) { + const Shape& element_shape = + is_variadic ? out_shape.tuple_shapes(i) : out_shape; + PrimitiveType accumulator_type = element_shape.element_type(); + llvm::Type* accumulator_llvm_type = + llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_); + accumulator_types.push_back(accumulator_llvm_type); + + // Initialize an accumulator with init_value. + llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry( + accumulator_llvm_type, "accumulator_" + std::to_string(i), b()); + TF_ASSIGN_OR_RETURN( + llvm::Value* const init_value, + initial_value_generators[i](llvm_ir::IrArray::Index(index_type))); + Store(init_value, accumulator_addr); + accumulator_addrs.push_back(accumulator_addr); + } + + // The enclosing loops go over all the target elements. Now we have to compute + // the actual target element. For this, we build a new loop nest to iterate + // over all the reduction dimensions in the argument. + // AddLoopsForShapeOnDimensions will return an Index where induction Value*s + // are placed for each dimension in dimensions, and all the rest are nullptrs. + llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), b(), index_type); + const HloInstruction* arg = reduce->operand(0); + std::vector input_multi_index = + loops.AddLoopsForShapeOnDimensions(arg->shape(), reduced_dimensions, + "reduction_dim"); + + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b()); + + // Build a full index for the input argument, using input_multi_index as the + // base. In input_multi_index only the reduction dimensions are filled in. We + // fill in the rest of the dimensions with induction Value*s taken from + // 'index' which iterates over the target array. See the high-level + // description in the XLA documentation for details. 
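+ // For example, reducing an f32[7,5,3] operand over dimension {1}: the loop
+ // nest above yields input_multi_index = {nullptr, r, nullptr}, and the two
+ // nullptrs are filled below from the output index {o0, o1}, giving the
+ // input index {o0, r, o1}.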
+ auto it = index.begin(); + + for (auto& i : input_multi_index) { + if (i == nullptr) { + i = *it++; + } + } + CHECK(index.end() == it); + llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(), + index_type); + + std::vector reduction_operands; + for (llvm::Value* accum : accumulator_addrs) { + llvm::Value* accum_value = Load(accum); + reduction_operands.push_back(accum_value); + } + + for (int i = 0; i < accumulators_count; i++) { + TF_ASSIGN_OR_RETURN(llvm::Value* const input_element, + input_generators[i](input_index)); + reduction_operands.push_back(input_element); + } + + TF_ASSIGN_OR_RETURN( + std::vector results, + EmitThreadLocalCall(*reduce->to_apply(), reduction_operands, + "reduce_function")); + + CHECK(results.size() == accumulators_count); + for (int i = 0; i < accumulators_count; i++) { + Store(results[i], accumulator_addrs[i]); + } + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b()); + + if (is_variadic) { + // Emit a structure, as that what the LoopEmitter expects. + llvm::Value* returned_structure = llvm::UndefValue::get( + llvm::StructType::get(b()->getContext(), accumulator_types)); + for (int i = 0; i < accumulators_count; i++) { + llvm::Value* accumulator_value = Load(accumulator_addrs[i]); + returned_structure = + b()->CreateInsertValue(returned_structure, accumulator_value, i); + } + return returned_structure; + } else { + CHECK_EQ(accumulator_addrs.size(), 1); + return Load(accumulator_addrs[0]); + } +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index 94e8f1d6400..06a9d7b194c 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -17,12 +17,17 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_ #include +#include +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h" #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/statusor.h" @@ -220,6 +225,26 @@ class ElementalIrEmitter : public IrBuilderMixin { const HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& dot_result_index); + virtual StatusOr> EmitThreadLocalCall( + const HloComputation& callee, absl::Span parameters, + absl::string_view name) = 0; + + StatusOr EmitElementalMap( + const HloMapInstruction* map_instr, + absl::Span elemental_operands); + + StatusOr EmitElementalReduceWindow( + const HloReduceWindowInstruction* reduce_window, + const llvm_ir::ElementGenerator& input_generator, + const llvm_ir::ElementGenerator& initial_value_generator, + const llvm_ir::IrArray::Index& index); + + StatusOr EmitElementalReduce( + const HloReduceInstruction* reduce, + std::vector input_generators, + std::vector initial_value_generators, + const llvm_ir::IrArray::Index& index); + llvm::IRBuilder<>* const b_; llvm::Module* module_; diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index f1ac1fef451..5d7bd26b01e 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -63,10 +63,6 @@ class ExecutionInput { explicit ExecutionInput(xla::Shape shape) : buffers_(std::move(shape)) {} explicit ExecutionInput(ShapeTree buffers) : buffers_(std::move(buffers)) {} - ExecutionInput(ShapeTree buffers, - std::vector owner_held_indices) - : buffers_(std::move(buffers)), - unowned_indices_(std::move(owner_held_indices)) {} ExecutionInput(ExecutionInput&&) = default; ~ExecutionInput() { diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 61bc41283e1..09382cfec3f 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -17,15 +17,15 @@ load( "tf_cuda_library", ) load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") -load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") +load( + "@local_config_rocm//rocm:build_defs.bzl", + "if_rocm", + "if_rocm_is_configured", +) load( "//tensorflow/core/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", ) -load( - "@local_config_rocm//rocm:build_defs.bzl", - "if_rocm_is_configured", -) load("//tensorflow:tensorflow.bzl", "if_nccl") package( @@ -684,7 +684,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", - "//tensorflow/core:autotuning_proto_cc", + "//tensorflow/core/protobuf:autotuning_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/util/proto:proto_utils", @@ -720,7 +720,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_pass", - "//tensorflow/core:autotuning_proto_cc", + 
"//tensorflow/core/protobuf:autotuning_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", @@ -819,6 +819,7 @@ cc_library( deps = [ # LINT.IfChange "@local_config_cuda//cuda:cublas_headers", + "@local_config_cuda//cuda:cusolver_headers", # LINT.ThenChange(//tensorflow/copy.bara.sky:cublas_headers) "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -901,12 +902,15 @@ cc_library( ":ir_emission_utils", "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_reachability", - "//tensorflow/compiler/xla/service:multi_output_fusion", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", ], ) @@ -1674,7 +1678,7 @@ tf_proto_library_cc( protodeps = [ "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_proto", - "//tensorflow/core:autotuning_proto", + "//tensorflow/core/protobuf:autotuning_proto", ], ) @@ -1685,8 +1689,8 @@ cc_library( deps = [ ":gpu_autotuning_proto_cc", "//tensorflow/compiler/xla:debug_options_flags", - "//tensorflow/core:autotuning_proto_cc", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/protobuf:autotuning_proto_cc", "@com_google_absl//absl/container:flat_hash_map", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index c6df786fb51..eee0fc83481 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -260,6 +260,13 @@ StatusOr GpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, StatusOr GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, llvm::Value* value) { + // When F64 is being requested, assume performance is less important and use + // the more numerically precise tanh function. + if (prim_type == F64) { + return EmitDeviceMathCall(TargetDeviceFunctionID::kTanh, {value}, + {prim_type}, prim_type); + } + // Emit a fast approximation of tanh instead of calling __nv_tanh. 
// __nv_tanh is particularly bad because it contains branches, thus // preventing LLVM's load-store vectorizer from working its magic across a @@ -305,168 +312,5 @@ llvm::Value* GpuElementalIrEmitter::EmitThreadId() { return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block); } -llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( - const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) { - switch (hlo->opcode()) { - case HloOpcode::kMap: - return [=, &operand_to_generator]( - const IrArray::Index& index) -> StatusOr { - TF_RET_CHECK(!hlo->operands().empty()) - << "Zero operand map not implemented in GPU backend."; - TF_RET_CHECK(hlo->to_apply()->num_parameters() > 0); - std::vector operand_elements; - for (HloInstruction* operand : hlo->operands()) { - TF_ASSIGN_OR_RETURN(llvm::Value * value, - operand_to_generator.at(operand)(index)); - operand_elements.push_back(value); - } - return compute_nested_(*hlo->to_apply(), operand_elements); - }; - case HloOpcode::kReduceWindow: - // Pseudocode: - // for each index I in output - // value = init_value - // for each index W in window - // for each dimension i from 0 to rank - 1 - // (input index I)[i] = O[i] * stride[i] + W[i] - pad_low[i] - // if I in bounds of input - // value = function(value, input[I]) - // output[O] = value - return [=, &operand_to_generator]( - const IrArray::Index& index) -> StatusOr { - const HloInstruction* operand = hlo->operand(0); - const Window& window = hlo->window(); - - PrimitiveType operand_element_type = operand->shape().element_type(); - llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), - "reduce_window_accum_ptr", b_); - { - TF_ASSIGN_OR_RETURN(llvm::Value * init_value, - operand_to_generator.at(hlo->operand(1))( - IrArray::Index(index.GetType()))); - Store(init_value, accum_ptr); - } - - llvm::Type* index_type = index.GetType(); - auto index_typed_const = [&](uint64 c) -> llvm::Constant* { - return index.GetConstantWithIndexType(c); - }; - - llvm_ir::ForLoopNest loops(IrName(hlo), b_, index_type); - std::vector window_size; - for (const auto& dim : window.dimensions()) { - window_size.push_back(dim.size()); - } - const IrArray::Index window_index = loops.AddLoopsForShape( - ShapeUtil::MakeShape(operand_element_type, window_size), "window"); - CHECK_EQ(window_index.size(), index.size()); - - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_); - - std::vector input_multi_index(index.size()); - llvm::Value* in_bounds = b_->getInt1(true); - for (size_t i = 0; i < index.size(); ++i) { - llvm::Value* stridden_index = NSWMul( - index[i], index_typed_const(window.dimensions(i).stride())); - input_multi_index[i] = NSWSub( - NSWAdd(stridden_index, - NSWMul(window_index[i], - index_typed_const( - window.dimensions(i).window_dilation()))), - index_typed_const(window.dimensions(i).padding_low())); - - // We need to verify that we are not in the dilated base area. - llvm::Value* dilation_condition = ICmpEQ( - SRem(input_multi_index[i], - index_typed_const(window.dimensions(i).base_dilation())), - index_typed_const(0)); - in_bounds = And(in_bounds, dilation_condition); - - // Apply base dilation to the index. - input_multi_index[i] = - SDiv(input_multi_index[i], - index_typed_const(window.dimensions(i).base_dilation())); - - // We must check whether 0 <= input_multi_index[i] < bound, as - // otherwise we are in the pad and so can skip the computation. 
This - // comparison is equivalent to the unsigned comparison - // input_multi_index[i] < bound, as a negative value wraps to a large - // positive value. - in_bounds = - And(in_bounds, - ICmpULT(input_multi_index[i], - index_typed_const(operand->shape().dimensions(i)))); - } - - llvm_ir::LlvmIfData if_data = - llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_); - SetToFirstInsertPoint(if_data.true_block, b_); - - // We are not in pad, so do the computation. - IrArray::Index input_index(input_multi_index, operand->shape(), - index_type); - TF_ASSIGN_OR_RETURN(llvm::Value * input_value, - operand_to_generator.at(operand)(input_index)); - TF_ASSIGN_OR_RETURN( - llvm::Value * accum_value, - compute_nested_(*hlo->to_apply(), {Load(accum_ptr), input_value})); - Store(accum_value, accum_ptr); - - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_); - return Load(accum_ptr); - }; - case HloOpcode::kReduce: - // TODO(b/118332391): This should be supported. - CHECK_EQ(hlo->operand_count(), 2) << "Did not expect variadic reduce"; - return [=, &operand_to_generator]( - const IrArray::Index& output_index) -> StatusOr { - const HloInstruction* operand = hlo->operand(0); - llvm::Value* accum_ptr = - b()->CreateAlloca(llvm_ir::PrimitiveTypeToIrType( - hlo->shape().element_type(), module_)); - llvm::Type* index_type = output_index.GetType(); - TF_ASSIGN_OR_RETURN(llvm::Value * init_value, - operand_to_generator.at(hlo->operand(1))( - IrArray::Index(index_type))); - b()->CreateStore(init_value, accum_ptr); - - llvm_ir::ForLoopNest loops(IrName(hlo), b_, index_type); - std::vector input_multi_index = - loops.AddLoopsForShapeOnDimensions( - operand->shape(), hlo->dimensions(), "reduction_dim"); - if (!ShapeUtil::IsScalar(hlo->shape())) { - // Here only input_multi_index[hlo->dimensions()] are non-null, so we - // must set the rest. - size_t j = 0; - for (auto& i : input_multi_index) { - if (i == nullptr) { - i = output_index[j++]; - } - } - CHECK_EQ(output_index.size(), j); - } - llvm_ir::IrArray::Index input_index( - input_multi_index, hlo->operand(0)->shape(), index_type); - - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b()); - TF_ASSIGN_OR_RETURN( - llvm::Value * input_value, - operand_to_generator.at(hlo->operand(0))(input_index)); - TF_ASSIGN_OR_RETURN( - llvm::Value * accum_value, - compute_nested_(*hlo->to_apply(), - {b()->CreateLoad(accum_ptr), input_value})); - b()->CreateStore(accum_value, accum_ptr); - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b()); - return b()->CreateLoad(accum_ptr); - }; - default: - return ElementalIrEmitter::MakeElementGenerator(hlo, - operand_to_generator); - } -} - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index c8a58a21980..a3056b1ddad 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -40,17 +40,13 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { public: // A NestedComputer computes an element of the output of the given computation // given a Span of its input elements. 
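// Why NestedComputer now returns a vector of values rather than a single one: a
// nested computation with a tuple-shaped root (e.g. the combiner of a variadic
// reduce) produces one scalar per tuple element. This standalone sketch models
// that calling convention with plain doubles; it illustrates the shape of the
// interface only, not the actual XLA/LLVM types.
#include <functional>
#include <vector>

// One result per element of the callee's tuple-shaped root (cf. the new
// StatusOr<std::vector<llvm::Value*>> return type above).
using ToyNestedComputer =
    std::function<std::vector<double>(const std::vector<double>&)>;

int main() {
  // Combiner of an argmax-style reduction: parameters are
  // {acc_value, acc_index, elem_value, elem_index}; the result is the new
  // {acc_value, acc_index} pair -- two scalars, hence the vector return.
  ToyNestedComputer argmax_combiner =
      [](const std::vector<double>& p) -> std::vector<double> {
    return p[2] > p[0] ? std::vector<double>{p[2], p[3]}
                       : std::vector<double>{p[0], p[1]};
  };
  std::vector<double> result = argmax_combiner({1.5, 0.0, 4.0, 7.0});
  return result.size() == 2 ? 0 : 1;
}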
- using NestedComputer = std::function( + using NestedComputer = std::function>( const HloComputation&, absl::Span)>; GpuElementalIrEmitter(const HloModuleConfig& hlo_module_config, llvm::Module* module, llvm::IRBuilder<>* b, NestedComputer compute_nested); - llvm_ir::ElementGenerator MakeElementGenerator( - const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) override; - protected: StatusOr EmitFloatBinaryOp(const HloInstruction* op, llvm::Value* lhs_value, @@ -92,6 +88,12 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { StatusOr EmitComplexAbs(PrimitiveType prim_type, llvm::Value* value) override; + StatusOr> EmitThreadLocalCall( + const HloComputation& callee, absl::Span parameters, + absl::string_view) override { + return compute_nested_(callee, parameters); + } + llvm::Value* EmitThreadId() override; private: diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc index 991a463f2a0..9d6be3c78ea 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc @@ -60,16 +60,18 @@ StatusOr> FftScratchAllocator::AllocateBytes( namespace { -se::fft::Type FftTypeToSeType(FftType type) { +se::fft::Type FftTypeToSeType(FftType type, bool double_precision) { switch (type) { case FftType::FFT: - return se::fft::Type::kC2CForward; + return double_precision ? se::fft::Type::kZ2ZForward + : se::fft::Type::kC2CForward; case FftType::IFFT: - return se::fft::Type::kC2CInverse; + return double_precision ? se::fft::Type::kZ2ZInverse + : se::fft::Type::kC2CInverse; case FftType::IRFFT: - return se::fft::Type::kC2R; + return double_precision ? se::fft::Type::kZ2D : se::fft::Type::kC2R; case FftType::RFFT: - return se::fft::Type::kR2C; + return double_precision ? 
se::fft::Type::kD2Z : se::fft::Type::kR2C; default: LOG(FATAL) << "unsupported fft type"; } @@ -78,12 +80,16 @@ se::fft::Type FftTypeToSeType(FftType type) { string FftTypeToString(se::fft::Type type) { switch (type) { case se::fft::Type::kC2CForward: + case se::fft::Type::kZ2ZForward: return "FFT"; case se::fft::Type::kC2CInverse: + case se::fft::Type::kZ2ZInverse: return "IFFT"; case se::fft::Type::kC2R: + case se::fft::Type::kZ2D: return "IRFFT"; case se::fft::Type::kR2C: + case se::fft::Type::kD2Z: return "RFFT"; default: LOG(FATAL) << "unknown fft type"; @@ -98,7 +104,9 @@ FftThunk::FftThunk(FftType fft_type, absl::Span fft_length, const Shape& input_shape, const Shape& output_shape, const HloInstruction* hlo) : Thunk(Kind::kFft, hlo), - fft_type_(FftTypeToSeType(fft_type)), + fft_type_( + FftTypeToSeType(fft_type, input_shape.element_type() == F64 || + input_shape.element_type() == C128)), fft_length_(fft_length.begin(), fft_length.end()), scale_factor_(1.0f), input_buffer_(input_buffer), @@ -166,6 +174,15 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) { stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok(); break; } + case se::fft::Type::kZ2ZForward: { + se::DeviceMemory input_data( + buffer_allocations.GetDeviceAddress(input_buffer_)); + se::DeviceMemory output_data( + buffer_allocations.GetDeviceAddress(output_buffer_)); + launch_ok = + stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok(); + break; + } case se::fft::Type::kC2CInverse: { se::DeviceMemory input_data( buffer_allocations.GetDeviceAddress(input_buffer_)); @@ -181,6 +198,22 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) { } break; } + case se::fft::Type::kZ2ZInverse: { + se::DeviceMemory input_data( + buffer_allocations.GetDeviceAddress(input_buffer_)); + se::DeviceMemory output_data( + buffer_allocations.GetDeviceAddress(output_buffer_)); + launch_ok = + stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok(); + if (launch_ok) { + launch_ok = + stream + .ThenBlasScal(ShapeUtil::ElementsIn(output_shape_), + complex128(scale_factor_), &output_data, 1) + .ok(); + } + break; + } case se::fft::Type::kR2C: { se::DeviceMemory input_data( buffer_allocations.GetDeviceAddress(input_buffer_)); @@ -190,6 +223,15 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) { stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok(); break; } + case se::fft::Type::kD2Z: { + se::DeviceMemory input_data( + buffer_allocations.GetDeviceAddress(input_buffer_)); + se::DeviceMemory output_data( + buffer_allocations.GetDeviceAddress(output_buffer_)); + launch_ok = + stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok(); + break; + } case se::fft::Type::kC2R: { se::DeviceMemory input_data( buffer_allocations.GetDeviceAddress(input_buffer_)); @@ -205,6 +247,21 @@ Status FftThunk::ExecuteOnStream(const ExecuteParams& params) { } break; } + case se::fft::Type::kZ2D: { + se::DeviceMemory input_data( + buffer_allocations.GetDeviceAddress(input_buffer_)); + se::DeviceMemory output_data( + buffer_allocations.GetDeviceAddress(output_buffer_)); + launch_ok = + stream.ThenFft(fft_plan_.get(), input_data, &output_data).ok(); + if (launch_ok) { + launch_ok = stream + .ThenBlasScal(ShapeUtil::ElementsIn(output_shape_), + scale_factor_, &output_data, 1) + .ok(); + } + break; + } default: LOG(FATAL) << "unsupported fft type"; } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 
2df6b50d361..b9d1f3ef158 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -176,7 +176,6 @@ Status GpuExecutable::ExecuteThunks( // module, we won't get any data, but that's probably an OK trade-off. ScopedAnnotation annotation([&] { return thunk->profile_annotation(); }); - TF_RETURN_IF_ERROR(thunk->Initialize(*this, executor)); int32 stream_no = thunk_schedule_->StreamNumberForHlo(*thunk->hlo_instruction()); se::Stream* stream = @@ -387,6 +386,10 @@ StatusOr GpuExecutable::ExecuteAsyncOnStream( assignment_.get(), executor->device_ordinal(), memory_allocator)); } + for (Thunk* thunk : thunk_schedule_->TotalOrder()) { + TF_RETURN_IF_ERROR(thunk->Initialize(*this, executor)); + } + TF_RETURN_IF_ERROR(ExecuteThunks(run_options, *buffer_allocations, block_host_until_done, hlo_execution_profile)); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc index 1316e8ad1aa..bb4184ff76f 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc @@ -351,6 +351,9 @@ bool FusionWouldBeTooLarge(const HloInstruction& instr1, const HloInstruction& instr2) { if (SharedMemoryUsage(instr1) + SharedMemoryUsage(instr2) > kSharedMemoryBudgetInBytes) { + VLOG(5) << "Shared memory usage of fusion of " << instr1.ToString() + << " and " << instr2.ToString() << " would be over the budget of " + << kSharedMemoryBudgetInBytes << "B"; return true; } @@ -383,6 +386,14 @@ bool FusionWouldBeTooLarge(const HloInstruction& instr1, num_output_buffers <= kMaxOperandsAndOutputsPerFusion) { return false; + } else { + VLOG(5) << "Operand count of " + << "(" << instr1.ToString() << " ) = " << instr1.operand_count() + << " and ( " << instr2.ToString() + << " ) = " << instr2.operand_count() + << " and num_output_buffers = " << num_output_buffers + << " is bigger than the bound of " + << kMaxOperandsAndOutputsPerFusion; } // Compute the precise number of operands to the new fusion. 
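// The new VLOG(5) messages above report which budget rejected a candidate
// fusion. A simplified, standalone restatement of the two checks, with
// illustrative constants and a plain struct instead of HloInstruction; the real
// limits and the precise operand counting live in gpu_fusible.cc.
#include <cstdint>
#include <iostream>

namespace {
constexpr int64_t kSharedMemoryBudgetInBytes = 48 * 1024;  // illustrative
constexpr int64_t kMaxOperandsAndOutputsPerFusion = 64;    // illustrative

struct FusionCandidate {
  int64_t shared_memory_bytes;
  int64_t operand_count;
};

bool FusionWouldBeTooLargeSketch(const FusionCandidate& a,
                                 const FusionCandidate& b,
                                 int64_t num_output_buffers) {
  if (a.shared_memory_bytes + b.shared_memory_bytes >
      kSharedMemoryBudgetInBytes) {
    std::cerr << "rejecting fusion: combined shared memory over budget\n";
    return true;
  }
  if (a.operand_count + b.operand_count + num_output_buffers >
      kMaxOperandsAndOutputsPerFusion) {
    std::cerr << "rejecting fusion: too many operands and outputs\n";
    return true;
  }
  return false;
}
}  // namespace

int main() {
  // 40KiB + 16KiB of shared memory exceeds the illustrative 48KiB budget.
  bool too_large = FusionWouldBeTooLargeSketch({40 << 10, 3}, {16 << 10, 2}, 1);
  return too_large ? 0 : 1;
}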
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index 05fa798dc39..cb22b4d9042 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -96,7 +96,8 @@ Status GpuTransferManager::EnqueueBuffersToInfeed( StatusOr GpuTransferManager::TransferBufferToInfeedInternal( se::StreamExecutor* executor, int64 size, const void* source) { if (size > std::numeric_limits::max()) { - return InvalidArgument("Infeed shape is too large: needs %d bytes", size); + return InvalidArgument("GPU infeed of %d bytes exceeds maximum of %d bytes", + size, std::numeric_limits::max()); } if (size == 0) { diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index fc1c1bb4ab1..a0580e2ab04 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -65,12 +65,16 @@ bool GpuInstructionFusion::ShouldFuseInexpensiveChecks(HloInstruction* consumer, bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, int64 operand_index) { if (!ShouldFuseInexpensiveChecks(consumer, operand_index)) { + VLOG(5) << "Not fusing inexpensive checks of operand " << operand_index + << " of " << consumer->ToString(); return false; } auto producer = consumer->operand(operand_index); // The following checks are potentially expensive. if (FusionWouldBeTooLarge(*consumer, *producer)) { + VLOG(5) << "Fusion of (" << producer->ToString() << ") into (" + << consumer->ToString() << ") would be too large"; return false; } if (consumer->opcode() != HloOpcode::kFusion) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 011eb07d3bd..aa8a6215cc7 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -222,7 +222,7 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( // Derive a minimum alignment from the type. The optimizer can increase it // later. store->setAlignment( - llvm::MaybeAlign(ShapeUtil::ByteSizeOfPrimitiveType(element_type))); + llvm::Align(ShapeUtil::ByteSizeOfPrimitiveType(element_type))); return true; } @@ -698,115 +698,6 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) { return Status::OK(); } -Status IrEmitter::HandleReduce(HloInstruction* instr) { - const HloReduceInstruction* reduce = Cast(instr); - const Shape& out_shape = reduce->shape(); - bool returns_tuple = !out_shape.IsArray(); - int accumulators_count = 1; - if (returns_tuple) { - CHECK(out_shape.IsTuple()); - accumulators_count = out_shape.tuple_shapes_size(); - } - - auto arg = reduce->operand(0); - absl::Span dimensions(reduce->dimensions()); - HloComputation* function = reduce->to_apply(); - return EmitTargetElementLoop( - *reduce, - [=](const llvm_ir::IrArray::Index& index) -> StatusOr { - std::vector accumulator_addrs; - std::vector accumulator_types; - - // Initialize accumulators with initial values. - for (int i = 0; i < accumulators_count; i++) { - auto init_value = reduce->init_values()[i]; - const Shape& element_shape = - returns_tuple ? 
out_shape.tuple_shapes(i) : out_shape; - PrimitiveType accumulator_type = element_shape.element_type(); - llvm::Type* accumulator_llvm_type = - llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_); - llvm::AllocaInst* accumulator_addr = Alloca(accumulator_llvm_type); - Store(Load(GetBasePointer(*init_value)), accumulator_addr); - accumulator_addrs.push_back(accumulator_addr); - accumulator_types.push_back(accumulator_llvm_type); - } - - // The enclosing loops go over all the target elements. Now we have to - // compute the actual target element. For this, we build a new loop nest - // to iterate over all the reduction dimensions in the argument. - // AddLoopsForShapeOnDimensions will return an Index where induction - // Value*s are placed for each dimension in dimensions, and all the rest - // are nullptrs. - llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_); - std::vector input_multi_index = - loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions, - "reduction_dim"); - - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); - - // Build a full index for the input argument, using reduced_dims_index - // as the base. In reduced_dims_index only the reduction dimensions are - // filled in. We fill in the rest of the dimensions with induction - // Value*s taken from 'index' which iterates over the target array. - // See the high-level description in the XLA documentation for details. - llvm_ir::IrArray::Index::const_iterator it = index.begin(); - - for (auto& i : input_multi_index) { - if (i == nullptr) { - i = *it++; - } - } - CHECK(index.end() == it); - - // Apply the reduction function to the loaded value. - llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(), - b_.getInt64Ty()); - std::vector reduction_operands(accumulator_addrs.begin(), - accumulator_addrs.end()); - for (int i = 0; i < accumulators_count; i++) { - llvm::Value* input_address = - GetIrArray(*reduce->operand(i), *reduce) - .EmitArrayElementAddress(input_index, &b_); - reduction_operands.push_back(input_address); - } - - llvm::Value* ret_argument; - if (!returns_tuple) { - CHECK_EQ(accumulator_addrs.size(), 1); - ret_argument = accumulator_addrs[0]; - } else { - const Shape& return_shape = function->root_instruction()->shape(); - - llvm::Type* return_value_buffer_type = - llvm_ir::ShapeToIrType(return_shape, module_); - ret_argument = Alloca(return_value_buffer_type); - llvm_ir::IrArray tuple_array(ret_argument, return_shape); - EmitTuple(tuple_array, accumulator_addrs, &b_); - } - - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *function, reduction_operands, ret_argument)); - - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); - - if (!returns_tuple) { - CHECK_EQ(accumulator_addrs.size(), 1); - return Load(accumulator_addrs[0]); - } else { - // Emit a struct for the LoopEmitter dealing with multi-output - // fusion. - llvm::Value* returned_structure = llvm::UndefValue::get( - llvm::StructType::get(b_.getContext(), accumulator_types)); - for (int i = 0; i < accumulators_count; i++) { - llvm::Value* accumulator_value = Load(accumulator_addrs[i]); - returned_structure = - b_.CreateInsertValue(returned_structure, accumulator_value, i); - } - return returned_structure; - } - }); -} - Status IrEmitter::HandleFusion(HloInstruction* fusion) { // kFusion for library calls should be handled by // IrEmitterUnnested::HandleFusion. 
@@ -866,22 +757,39 @@ Status IrEmitter::HandleBatchNormGrad(HloInstruction*) { "to a cudnn CustomCall using CudnnBatchNormRewriter."); } -StatusOr IrEmitter::ComputeNestedElement( +StatusOr> IrEmitter::ComputeNestedElement( const HloComputation& computation, absl::Span parameter_elements) { + const Shape& return_shape = computation.root_instruction()->shape(); llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType( - computation.root_instruction()->shape().element_type(), module_), - "return_buffer", &b_); + llvm_ir::ShapeToIrType(return_shape, module_), "return_buffer", &b_); std::vector parameter_buffers; for (llvm::Value* parameter_element : parameter_elements) { parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry( parameter_element->getType(), "parameter_buffer", &b_)); Store(parameter_element, parameter_buffers.back()); } + + std::vector allocas_for_returned_scalars; + if (!return_shape.IsTuple()) { + allocas_for_returned_scalars.push_back(return_buffer); + } else { + allocas_for_returned_scalars = + llvm_ir::EmitTupleAllocasAtFunctionEntry(return_shape, &b_); + llvm_ir::IrArray tuple_array(return_buffer, return_shape); + + EmitTuple(tuple_array, allocas_for_returned_scalars, &b_); + } + TF_RETURN_IF_ERROR(EmitCallToNestedComputation(computation, parameter_buffers, return_buffer)); - return Load(return_buffer); + + std::vector returned_scalars; + returned_scalars.reserve(allocas_for_returned_scalars.size()); + for (llvm::Value* addr : allocas_for_returned_scalars) { + returned_scalars.push_back(Load(addr)); + } + return returned_scalars; } std::vector IrEmitter::ConstructIrArrayForOutputs( diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index e0fe454dcfe..93712961ea2 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -89,7 +89,6 @@ class IrEmitter : public DfsHloVisitorWithDefault, Status HandleRecv(HloInstruction* recv) override; Status HandleRecvDone(HloInstruction* recv_done) override; Status HandleParameter(HloInstruction* parameter) override; - Status HandleReduce(HloInstruction* reduce) override; Status HandleTuple(HloInstruction* tuple) override; Status HandleScatter(HloInstruction* scatter) override; Status HandleSelect(HloInstruction* select) override; @@ -213,7 +212,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, const llvm_ir::IrArray::Index& compare_keys_index, const llvm_ir::IrArray& keys_array); - StatusOr ComputeNestedElement( + StatusOr> ComputeNestedElement( const HloComputation& computation, absl::Span parameter_elements); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index ec5f10bd2e8..a78ffc8dd1a 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -2016,7 +2016,9 @@ void IrEmitterUnnested::EmitTile( // True iff all threads always execute all instructions in the tiling // dimension X. 
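// Standalone restatement of the predicate changed just below: per-element bounds
// checks in the tiled X dimension may only be skipped when the dimension divides
// evenly into tiles *and* the mapping scheme reports contiguous rows. Plain
// types stand in for KernelMappingScheme; kDimX's value here is illustrative.
#include <array>
#include <cstdint>

constexpr int kDimX = 2;  // {z, y, x} ordering, as in the mapping scheme

bool XTileFits(const std::array<int64_t, 3>& dims_in_elems,
               int64_t tile_size_x, bool row_contiguous) {
  return dims_in_elems[kDimX] % tile_size_x == 0 && row_contiguous;
}

int main() {
  // 512 elements in X split into 64-wide tiles fits, but only if rows are
  // contiguous; the 0-2-1 transpose path passes /*is_row_contiguous=*/false.
  return XTileFits({8, 32, 512}, 64, /*row_contiguous=*/true) ? 0 : 1;
}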
- bool x_tile_fits = mapping_scheme.GetDimsInElems()[kDimX] % tile_size_x == 0; + bool x_tile_fits = + mapping_scheme.GetDimsInElems()[kDimX] % tile_size_x == 0 && + mapping_scheme.GetRowContiguous(); // The outer loop below is simply doing: // @@ -2731,7 +2733,8 @@ void IrEmitterUnnested::EmitHlo021Tile( /*num_threads_y=*/kNumRows, /*num_threads_x=*/kWarpSize, /*indexing_order=*/kLinearIndexingX, - /*vector_size=*/1); + /*vector_size=*/1, + /*is_row_contiguous=*/false); LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(), mapping_scheme.GetThreadsPerBlock()); llvm::Type* index_type = diff --git a/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h b/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h index 5e15d0767a1..d5c4ecbc795 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h @@ -90,13 +90,14 @@ class KernelMappingScheme { KernelMappingScheme(absl::Span dims_in_elems, absl::Span tile_sizes, int64 num_threads_y, int64 num_threads_x, IndexingOrder indexing_order, - int vector_size) + int vector_size, bool is_row_contiguous = false) : dims_in_elems_{dims_in_elems[0], dims_in_elems[1], dims_in_elems[2]}, tile_sizes_{tile_sizes[0], tile_sizes[1], tile_sizes[2]}, num_threads_x_(num_threads_x), num_threads_y_(num_threads_y), indexing_order_(indexing_order), - vector_size_(vector_size) { + vector_size_(vector_size), + is_row_contiguous_(is_row_contiguous) { CHECK_EQ(tile_sizes[1] % num_threads_y_, 0); CHECK_EQ(tile_sizes[2] % num_threads_x_, 0); VLOG(10) << "dims_in_elems_ = " << absl::StrJoin(dims_in_elems_, ","); @@ -134,6 +135,7 @@ class KernelMappingScheme { IndexingOrder GetIndexingOrder() const { return indexing_order_; } int GetVectorSize() const { return vector_size_; } + bool GetRowContiguous() const { return is_row_contiguous_; } private: // The number of elements in each dimension. @@ -159,6 +161,7 @@ class KernelMappingScheme { // to trigger vectorized loads on GPUs while keeping memory // coalescing. const int vector_size_; + const bool is_row_contiguous_; }; // Information to support the code generation for a tiled reduction kernel. diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 060a0375271..497dcda4361 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -689,7 +689,7 @@ std::unique_ptr AMDGPUGetTargetMachine( llvm::Triple target_triple, int amdgpu_version, const HloModuleConfig& hlo_module_config) { return GetTargetMachine(target_triple, absl::StrCat("gfx", amdgpu_version), - hlo_module_config, "-code-object-v3"); + hlo_module_config, "+code-object-v3"); } void AMDGPUBackendInit(const HloModuleConfig& hlo_module_config) { diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h index 8d2ef53bfa9..e60f3bc3c14 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h @@ -16,7 +16,15 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ -#include "tensorflow/compiler/xla/service/multi_output_fusion.h" +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" +#include "tensorflow/compiler/xla/statusor.h" namespace xla { namespace gpu { diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 0196267d904..eefa4661d37 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -385,6 +385,19 @@ std::vector NVPTXCompiler::CompileGpuAsmOrGetCachedResult( } else { if (maybe_cubin.status().code() == tensorflow::error::Code::NOT_FOUND) { + if (!hlo_module_config.debug_options() + .xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found()) { + PrintCantFindCudaMessage( + "Can't find ptxas binary in ${CUDA_DIR}/bin. Custom ptxas " + "location can be specified using $PATH.", + hlo_module_config); + LOG(FATAL) + << "Can't find ptxas binary. You can pass the flag " + "--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found " + "to use the GPU driver for compiling ptx instead. However " + "this option is discouraged and can lead to increased " + "memory consumption and other subtle runtime issues."; + } // Missing ptxas is expected in some environments where CUDA SDK // binaries are not available. We don't want to spam logs with // identical warnings in this case. @@ -402,25 +415,13 @@ std::vector NVPTXCompiler::CompileGpuAsmOrGetCachedResult( "using $PATH.", hlo_module_config); } - CHECK(hlo_module_config.debug_options() - .xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found()) - << "There was an error when trying to compile ptx into sass " - "code. If you want to try falling back to the GPU driver to " - "jit compile ptx, you can use the flag " - "--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found." - " Use at your own risk though, it has known drawbacks like " - "increased memory consumption."; } else { - LOG(ERROR) << "Error during compilation of ptx to sass: " - << maybe_cubin.status(); - CHECK(hlo_module_config.debug_options() - .xla_gpu_unsafe_fallback_to_driver_on_ptxas_error()) - << "There was an error when trying to compile ptx into sass " - "code. If you want to try falling back to the GPU driver to " - "jit compile ptx, you can use the flag " - "--xla_gpu_unsafe_fallback_to_driver_on_ptxas_error." 
- " Use at your own risk though, it has known drawbacks like " - "increased memory consumption."; + LOG(FATAL) << "ptxas returned an error during compilation of ptx " + "to sass: '" + << maybe_cubin.status() << "' " + << "If the error message indicates that a file could " + "not be written, please verify that sufficient " + "filesystem space is provided."; } // We're going to use the driver to JIT our PTX->SASS, so warn if diff --git a/tensorflow/compiler/xla/service/gpu/target_util.cc b/tensorflow/compiler/xla/service/gpu/target_util.cc index 49eadd8c6be..31b590a19ff 100644 --- a/tensorflow/compiler/xla/service/gpu/target_util.cc +++ b/tensorflow/compiler/xla/service/gpu/target_util.cc @@ -111,47 +111,50 @@ struct TargetDeviceFunction { struct TargetDeviceFunction GetDeviceFunctionRoot( TargetDeviceFunctionID func_id) { switch (func_id) { - case TargetDeviceFunctionID::kPow: { - return {"__nv_pow", "__ocml_pow"}; - } - case TargetDeviceFunctionID::kErfcinv: { - return {"__nv_erfcinv", "__ocml_erfcinv"}; - } - case TargetDeviceFunctionID::kLog: { - return {"__nv_log", "__ocml_log"}; - } - case TargetDeviceFunctionID::kLog1p: { - return {"__nv_log1p", "__ocml_log1p"}; - } - case TargetDeviceFunctionID::kSin: { - return {"__nv_sin", "__ocml_sin"}; + case TargetDeviceFunctionID::kAtan2: { + return {"__nv_atan2", "__ocml_atan2"}; } case TargetDeviceFunctionID::kCos: { return {"__nv_cos", "__ocml_cos"}; } + case TargetDeviceFunctionID::kErfcinv: { + return {"__nv_erfcinv", "__ocml_erfcinv"}; + } case TargetDeviceFunctionID::kExp: { return {"__nv_exp", "__ocml_exp"}; } case TargetDeviceFunctionID::kExpm1: { return {"__nv_expm1", "__ocml_expm1"}; } - case TargetDeviceFunctionID::kSqrt: { - return {"__nv_sqrt", "__ocml_sqrt"}; - } - case TargetDeviceFunctionID::kRsqrt: { - return {"__nv_rsqrt", "__ocml_rsqrt"}; - } - case TargetDeviceFunctionID::kAtan2: { - return {"__nv_atan2", "__ocml_atan2"}; - } case TargetDeviceFunctionID::kFmod: { return {"__nv_fmod", "__ocml_fmod"}; } + case TargetDeviceFunctionID::kHypot: { + return {"__nv_hypot", "__ocml_hypot"}; + } + case TargetDeviceFunctionID::kLog: { + return {"__nv_log", "__ocml_log"}; + } + case TargetDeviceFunctionID::kLog1p: { + return {"__nv_log1p", "__ocml_log1p"}; + } + case TargetDeviceFunctionID::kPow: { + return {"__nv_pow", "__ocml_pow"}; + } case TargetDeviceFunctionID::kRound: { return {"__nv_round", "__ocml_round"}; } - case TargetDeviceFunctionID::kHypot: { - return {"__nv_hypot", "__ocml_hypot"}; + case TargetDeviceFunctionID::kRsqrt: { + return {"__nv_rsqrt", "__ocml_rsqrt"}; + } + case TargetDeviceFunctionID::kSin: { + return {"__nv_sin", "__ocml_sin"}; + } + case TargetDeviceFunctionID::kSqrt: { + return {"__nv_sqrt", "__ocml_sqrt"}; + } + case TargetDeviceFunctionID::kTanh: { + return {"__nv_tanh", "__ocml_tanh"}; } } } diff --git a/tensorflow/compiler/xla/service/gpu/target_util.h b/tensorflow/compiler/xla/service/gpu/target_util.h index 4355ed21136..2bdaea7734a 100644 --- a/tensorflow/compiler/xla/service/gpu/target_util.h +++ b/tensorflow/compiler/xla/service/gpu/target_util.h @@ -46,20 +46,21 @@ enum class TargetIntrinsicID { // Enumeration to get target specific device math function. 
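// How a TargetDeviceFunctionID such as the newly added kTanh becomes a concrete
// callee: the root names below ("__nv_tanh" / "__ocml_tanh") are combined with
// the target and the operand precision. The suffix rule in this sketch ("f" for
// F32 on NVPTX, "_f32"/"_f64" on AMDGPU) is an assumption made for illustration
// only; the authoritative logic is ObtainDeviceFunctionName in target_util.cc,
// which is not part of this hunk.
#include <stdexcept>
#include <string>

enum class ToyTriple { kNvptx, kAmdgpu };

std::string DeviceFunctionName(const std::string& nvptx_root,
                               const std::string& amdgpu_root, ToyTriple triple,
                               bool is_f32) {
  switch (triple) {
    case ToyTriple::kNvptx:
      return is_f32 ? nvptx_root + "f" : nvptx_root;
    case ToyTriple::kAmdgpu:
      return amdgpu_root + (is_f32 ? "_f32" : "_f64");
  }
  throw std::logic_error("unknown triple");
}

int main() {
  // kTanh at F64 on NVPTX resolves to "__nv_tanh", matching the new F64
  // fallback added to GpuElementalIrEmitter::EmitTanh earlier in this change.
  return DeviceFunctionName("__nv_tanh", "__ocml_tanh", ToyTriple::kNvptx,
                            /*is_f32=*/false) == "__nv_tanh"
             ? 0
             : 1;
}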
enum class TargetDeviceFunctionID { - kPow = 0, - kErfcinv, - kLog, - kLog1p, - kSin, + kAtan2 = 0, kCos, + kErfcinv, kExp, kExpm1, - kSqrt, - kRsqrt, - kAtan2, kFmod, + kHypot, + kLog, + kLog1p, + kPow, kRound, - kHypot + kRsqrt, + kSin, + kSqrt, + kTanh, }; // Emits IR to call a device function named "callee_name" on the given diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index e04dba418d9..7a9845d0f49 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -235,6 +235,20 @@ tf_cc_test( ], ) +tf_cc_test( + name = "gpu_copy_alone_test", + srcs = [ + "gpu_copy_alone_test.cc", + ], + tags = tf_cuda_tests_tags() + ["no_rocm"], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "gpu_ftz_test", srcs = ["gpu_ftz_test.cc"], diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_alone_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_alone_test.cc new file mode 100644 index 00000000000..1c475ab4e10 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_alone_test.cc @@ -0,0 +1,61 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" + +namespace xla { +namespace gpu { + +namespace { + +// WARNING: This test must be alone in its file! Otherwise, the +// error isn't caught. We expect a CUDA_ERROR_ILLEGAL_ADDRESS to be +// thrown with the old buggy code. +class CopyAloneNoOptTest : public GpuCodegenTest { + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); + // The test MultiOutputStore contains a MOF fusion and the XLA optimizer pass + // doesn't like this. 
+ debug_options.set_xla_disable_all_hlo_passes(true); + return debug_options; + } +}; + +TEST_F(CopyAloneNoOptTest, CopyTranspose) { + const char* hlo_text = R"( +HloModule mod +ENTRY main { + %param = f32[8,32,32,32,16]{4,3,2,1,0} parameter(0) + ROOT %copy = f32[8,32,32,32,16]{3,2,1,4,0} copy(f32[8,32,32,32,16]{4,3,2,1,0} %param) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, + ParseAndReturnVerifiedModule(hlo_text)); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + + CompileAndOptionallyVerifyPtx(std::move(optimized_module), + R"( +CHECK-NOT: ld.global.nc.v2 +)"); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 94a4df43cf4..8a31bc5fef4 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -707,6 +707,10 @@ Status HloCostAnalysis::HandleCholesky(const HloInstruction* hlo) { return Status::OK(); } +Status HloCostAnalysis::HandleAllGather(const HloInstruction* hlo) { + return Status::OK(); +} + Status HloCostAnalysis::HandleAllReduce(const HloInstruction* crs) { // We assume 2 replicas, so that each output element is the sum of two input // elements. @@ -732,6 +736,16 @@ Status HloCostAnalysis::HandleCollectivePermute(const HloInstruction* /*hlo*/) { return Status::OK(); } +Status HloCostAnalysis::HandleCollectivePermuteStart( + const HloInstruction* /*hlo*/) { + return Status::OK(); +} + +Status HloCostAnalysis::HandleCollectivePermuteDone( + const HloInstruction* /*hlo*/) { + return Status::OK(); +} + Status HloCostAnalysis::HandlePartitionId(const HloInstruction* /*hlo*/) { return Status::OK(); } @@ -1027,6 +1041,42 @@ float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const { return GetPropertyForHlo(hlo, kOptimalSecondsKey, hlo_properties_); } +int64 HloCostAnalysis::GetBytesRead(const HloInstruction& hlo, + absl::optional memory_space) const { + int64 bytes_read = 0; + for (int operand_number = 0; operand_number < hlo.operand_count(); + ++operand_number) { + for (const ShapeUtil::IndexedShape& indexed_shape : + ShapeUtil::GetLeafShapes(hlo.operand(operand_number)->shape())) { + absl::optional index_memory_space; + if (indexed_shape.shape.has_layout()) { + index_memory_space = indexed_shape.shape.layout().memory_space(); + } + if (!memory_space || memory_space == index_memory_space) { + bytes_read += + operand_bytes_accessed(hlo, operand_number, indexed_shape.index); + } + } + } + return bytes_read; +} + +int64 HloCostAnalysis::GetBytesWritten( + const HloInstruction& hlo, absl::optional memory_space) const { + int64 bytes_written = 0; + for (const ShapeUtil::IndexedShape& indexed_shape : + ShapeUtil::GetLeafShapes(hlo.shape())) { + absl::optional index_memory_space; + if (indexed_shape.shape.has_layout()) { + index_memory_space = indexed_shape.shape.layout().memory_space(); + } + if (!memory_space || memory_space == index_memory_space) { + bytes_written += output_bytes_accessed(hlo, indexed_shape.index); + } + } + return bytes_written; +} + StatusOr HloCostAnalysis::ProcessSubcomputation( HloComputation* computation) { auto visitor = CreateNestedCostAnalysis(shape_size_, per_second_rates_); diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 915c4dcbe84..d9085dd7785 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ 
b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -76,9 +76,12 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleFft(const HloInstruction* fft) override; Status HandleTriangularSolve(const HloInstruction* hlo) override; Status HandleCholesky(const HloInstruction* hlo) override; + Status HandleAllGather(const HloInstruction* hlo) override; Status HandleAllReduce(const HloInstruction* crs) override; Status HandleAllToAll(const HloInstruction* hlo) override; Status HandleCollectivePermute(const HloInstruction* hlo) override; + Status HandleCollectivePermuteStart(const HloInstruction* hlo) override; + Status HandleCollectivePermuteDone(const HloInstruction* hlo) override; Status HandleReplicaId(const HloInstruction* hlo) override; Status HandlePartitionId(const HloInstruction* hlo) override; Status HandleInfeed(const HloInstruction* infeed) override; @@ -161,6 +164,14 @@ class HloCostAnalysis : public ConstDfsHloVisitor { ShapeIndex index = {}) const; float optimal_seconds(const HloInstruction& hlo) const; + // Get bytes read/written by this HLO. If memory_space is provided, it returns + // the bytes read/written from/to the given memory space only. + int64 GetBytesRead(const HloInstruction& hlo, + absl::optional memory_space = absl::nullopt) const; + int64 GetBytesWritten( + const HloInstruction& hlo, + absl::optional memory_space = absl::nullopt) const; + const Properties& properties() const { return properties_sum_; } const float property(const string& key) const { return GetProperty(key, properties()); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 4894e566393..d0d533e0b06 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -676,6 +676,39 @@ bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) { } } +bool HloDataflowAnalysis::UpdateCollectivePermuteStartValueSet( + HloInstruction* collective_permute_start) { + CHECK_EQ(collective_permute_start->opcode(), + HloOpcode::kCollectivePermuteStart); + bool changed = false; + // CollectivePermuteStart forwards the operand value to element {0} of its + // output. + const HloValueSet& operand_value_set = + GetValueSet(collective_permute_start->operand(0)); + HloValueSet& value_set = GetValueSet(collective_permute_start, {0}); + if (value_set != operand_value_set) { + value_set = operand_value_set; + changed = true; + } + return changed; +} + +bool HloDataflowAnalysis::UpdateCollectivePermuteDoneValueSet( + HloInstruction* collective_permute_done) { + CHECK_EQ(collective_permute_done->opcode(), + HloOpcode::kCollectivePermuteDone); + bool changed = false; + // CollectivePermuteDone forwards the operand value at {0} to its output. + const HloValueSet& operand_value_set = + GetValueSet(collective_permute_done->operand(0), {1}); + HloValueSet& value_set = GetValueSet(collective_permute_done); + if (value_set != operand_value_set) { + value_set = operand_value_set; + changed = true; + } + return changed; +} + bool HloDataflowAnalysis::UpdateInstructionValueSet( HloInstruction* instruction) { // Recompute from operands. 
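// The two Update* functions added above thread HloValue sets through the async
// pair: collective-permute-start forwards its operand's values into element {0}
// of its tuple-shaped result, and collective-permute-done is fed from element
// {1} of the start op. A toy model of that forwarding using string value-sets
// keyed by (instruction name, tuple index); names and the flat map are purely
// illustrative, not the HloDataflowAnalysis data structures.
#include <map>
#include <set>
#include <string>
#include <utility>

using ToyValueSet = std::set<std::string>;
using ToyKey = std::pair<std::string, int>;  // {instruction, tuple index}

int main() {
  std::map<ToyKey, ToyValueSet> values;
  values[{"operand", -1}] = {"v_operand"};  // -1: value of a non-tuple result

  // Start: the operand's values reappear at element {0} of the start's tuple.
  values[{"cp-start", 0}] = values[{"operand", -1}];

  // Element {1} (the destination buffer) is defined by the start op itself...
  values[{"cp-start", 1}] = {"v_cp_start_dest"};
  // ...and Done forwards exactly that element to its own result.
  values[{"cp-done", -1}] = values[{"cp-start", 1}];

  return values[{"cp-done", -1}].count("v_cp_start_dest") == 1 ? 0 : 1;
}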
@@ -712,6 +745,10 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( return UpdateCopyDoneValueSet(instruction); case HloOpcode::kConditional: return UpdateConditionalValueSet(instruction); + case HloOpcode::kCollectivePermuteStart: + return UpdateCollectivePermuteStartValueSet(instruction); + case HloOpcode::kCollectivePermuteDone: + return UpdateCollectivePermuteDoneValueSet(instruction); default: // Instruction does not forward HloValues (it defines all values in its // output). No update is necessary. diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index 75bcf7ea318..bec592aeb20 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -216,6 +216,10 @@ class HloDataflowAnalysis { bool UpdateTupleValueSet(HloInstruction* tuple); bool UpdateWhileValueSet(HloInstruction* xla_while); bool UpdateAddDependencyValueSet(HloInstruction* add_dependency); + bool UpdateCollectivePermuteStartValueSet( + HloInstruction* collective_permute_start); + bool UpdateCollectivePermuteDoneValueSet( + HloInstruction* collective_permute_done); // Propagates the dataflow through the module. In particular, it propagates // the HloValueSet from its defining instruction to the users of the diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index b8e3f83b515..900b557b4dc 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -47,16 +47,14 @@ StatusOr HloDCE::RunOnComputation( // computation's instruction while simultaneously removing instructions. std::vector dead_roots; for (auto* instruction : computation->instructions()) { + auto maybe_collective_op = DynCast(instruction); if (instruction != computation->root_instruction() && instruction->user_count() == 0 && computation->IsSafelyRemovable(instruction) && (!instruction->HasSideEffect() || (remove_cross_partition_collective_ops && - ((instruction->opcode() == HloOpcode::kAllReduce && - !Cast(instruction)->constrain_layout()) || - (instruction->opcode() == HloOpcode::kAllToAll && - !Cast(instruction)->constrain_layout()) || - instruction->opcode() == HloOpcode::kCollectivePermute)))) { + (maybe_collective_op != nullptr && + !maybe_collective_op->constrain_layout())))) { dead_roots.push_back(instruction); } } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 106ebb7be0e..02443ff3c3c 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -2556,6 +2556,20 @@ std::unique_ptr> HloEvaluator::MatmulArray2D( lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF64); } +std::unique_ptr>> HloEvaluator::MatmulArray2D( + const Array2D>& lhs, + const Array2D>& rhs) { + return MatmulArray2DImpl>( + lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulC64); +} + +std::unique_ptr>> HloEvaluator::MatmulArray2D( + const Array2D>& lhs, + const Array2D>& rhs) { + return MatmulArray2DImpl>( + lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulC128); +} + std::unique_ptr> HloEvaluator::MatmulArray2D( const Array2D& lhs, const Array2D& rhs) { return MatmulArray2DImpl( diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 803004225d2..dcd4129adcd 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ 
b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -164,6 +164,12 @@ class HloEvaluator : public DfsHloVisitorWithDefault { const Array2D& lhs, const Array2D& rhs); static std::unique_ptr> MatmulArray2D( const Array2D& lhs, const Array2D& rhs); + static std::unique_ptr>> MatmulArray2D( + const Array2D>& lhs, + const Array2D>& rhs); + static std::unique_ptr>> MatmulArray2D( + const Array2D>& lhs, + const Array2D>& rhs); static std::unique_ptr> MatmulArray2D( const Array2D& lhs, const Array2D& rhs); diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 47a455ac3f4..ad21efa13c9 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -312,12 +312,13 @@ optional MatchTrivialComputation(const HloComputation* computation) { class HloDotDumper { public: HloDotDumper(const HloComputation* computation, absl::string_view label, - const DebugOptions& debug_options, bool show_backend_config, + const DebugOptions& debug_options, + HloRenderOptions hlo_render_options, const HloExecutionProfile* profile, NodeFilter filter) : computation_(computation), label_(label), debug_options_(debug_options), - show_backend_config_(show_backend_config), + hlo_render_options_(hlo_render_options), profile_(profile), filter_(std::move(filter)) {} @@ -384,7 +385,7 @@ class HloDotDumper { const HloComputation* computation_; // never null const string label_; // overall name for the graph const DebugOptions& debug_options_; - const bool show_backend_config_; + const HloRenderOptions hlo_render_options_; const HloExecutionProfile* profile_; // may be null const NodeFilter filter_; @@ -565,7 +566,8 @@ bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) { bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) { if (subcomp->IsFusionComputation()) { const HloInstruction* fusion = subcomp->FusionInstruction(); - if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion)) { + if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion) || + !hlo_render_options_.show_fusion_subcomputations) { return false; } } @@ -1057,9 +1059,12 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kGetDimensionSize: case HloOpcode::kSetDimensionSize: return kGray; + case HloOpcode::kAllGather: case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: + case HloOpcode::kCollectivePermuteStart: + case HloOpcode::kCollectivePermuteDone: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kPartitionId: @@ -1130,7 +1135,8 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { string HloDotDumper::GetInstructionNodeBackendConfig( const HloInstruction* instr) { - if (!show_backend_config_ || instr->raw_backend_config_string().empty()) { + if (!hlo_render_options_.show_backend_config || + instr->raw_backend_config_string().empty()) { return ""; } @@ -1601,14 +1607,14 @@ StatusOr RenderGraph(const HloComputation& computation, const DebugOptions& debug_options, RenderedGraphFormat format, const HloExecutionProfile* hlo_execution_profile, - bool show_backend_config) { + HloRenderOptions hlo_render_options) { tensorflow::mutex_lock lock(url_renderer_mu); if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) { return Unavailable("Can't render as URL; no URL renderer was registered."); } string rendered_dot = - 
HloDotDumper(&computation, label, debug_options, show_backend_config, + HloDotDumper(&computation, label, debug_options, hlo_render_options, hlo_execution_profile, NodeFilter()) .Dump(); return WrapDotInFormat(rendered_dot, format); @@ -1616,7 +1622,7 @@ StatusOr RenderGraph(const HloComputation& computation, StatusOr RenderNeighborhoodAround( const HloInstruction& node, int radius, RenderedGraphFormat format, - bool show_backend_config, + HloRenderOptions hlo_render_options, const absl::flat_hash_set& boundary) { tensorflow::mutex_lock lock(url_renderer_mu); if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) { @@ -1629,7 +1635,7 @@ StatusOr RenderNeighborhoodAround( string rendered_dot = HloDotDumper(node.parent(), label, node.GetModule()->config().debug_options(), - show_backend_config, /*profile=*/nullptr, + hlo_render_options, /*profile=*/nullptr, MakeNodeRadiusAroundFilter(&node, radius, boundary)) .Dump(); return WrapDotInFormat(rendered_dot, format); @@ -1638,7 +1644,7 @@ StatusOr RenderNeighborhoodAround( StatusOr RenderAllPathsFromTo(const HloInstruction& from, const HloInstruction& to, int64 max_nodes, RenderedGraphFormat format, - bool show_backend_config) { + HloRenderOptions hlo_render_options) { tensorflow::mutex_lock lock(url_renderer_mu); if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) { return FailedPrecondition( @@ -1660,7 +1666,7 @@ StatusOr RenderAllPathsFromTo(const HloInstruction& from, "NODES***

"); } string rendered_dot = - HloDotDumper(from.parent(), label, debug_options, show_backend_config, + HloDotDumper(from.parent(), label, debug_options, hlo_render_options, /*profile=*/nullptr, filter) .Dump(); return WrapDotInFormat(rendered_dot, format); diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index 324ac67a6dd..528de77e4e6 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -50,6 +50,14 @@ enum class RenderedGraphFormat { kUrl, }; +struct HloRenderOptions { + // Include the backend config string in the rendered graph. + bool show_backend_config = false; + + // Include the fusion subcomputations in the rendered graph. + bool show_fusion_subcomputations = true; +}; + // Renders an HLO module as a human-readable visual graph. // // Note that this only works well for relatively small graphs (no more than a @@ -61,7 +69,7 @@ StatusOr RenderGraph( const HloComputation& computation, absl::string_view label, const DebugOptions& debug_options, RenderedGraphFormat format, const HloExecutionProfile* hlo_execution_profile = nullptr, - bool show_backend_config = false); + HloRenderOptions hlo_render_options = {}); // Like RenderGraph, but renders only nodes "near" the given node in the graph. // @@ -73,7 +81,7 @@ StatusOr RenderGraph( // will be omitted even if they are within the radius. StatusOr RenderNeighborhoodAround( const HloInstruction& node, int radius, RenderedGraphFormat format, - bool show_backend_config = false, + HloRenderOptions hlo_render_options = {}, const absl::flat_hash_set& boundary = {}); // Renders nodes on any of the paths from `from` to `to`. If there are more @@ -82,7 +90,7 @@ StatusOr RenderNeighborhoodAround( StatusOr RenderAllPathsFromTo(const HloInstruction& from, const HloInstruction& to, int64 max_nodes, RenderedGraphFormat format, - bool show_backend_config = false); + HloRenderOptions hlo_render_options = {}); // Registers a function which implements RenderedGraphFormat::kUrl. 
// diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 27fac19587e..c02100debc3 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -388,6 +388,24 @@ StatusOr> HloInstruction::CreateFromProto( proto.outfeed_config()); break; } + case HloOpcode::kAllGather: { + absl::optional channel_id; + if (proto.channel_id() > 0) { + channel_id = proto.channel_id(); + } + + TF_RET_CHECK(proto.dimensions_size() == 1) + << "AllGather cannot have more than 1 all-gather dimensions"; + TF_RET_CHECK(all_operands().size() == 1) + << "AllGather must have a single operand"; + int64 all_gather_dimension = proto.dimensions(0); + instruction = CreateAllGather( + shape, operands(0), all_gather_dimension, + std::vector(proto.replica_groups().begin(), + proto.replica_groups().end()), + proto.constrain_layout(), channel_id, proto.use_global_device_ids()); + break; + } case HloOpcode::kAllReduce: { TF_RET_CHECK(proto.called_computation_ids_size() == 1) << "AllReduce should have 1 called computation but sees " @@ -434,7 +452,8 @@ StatusOr> HloInstruction::CreateFromProto( /*channel_id=*/channel_id, split_dimension); break; } - case HloOpcode::kCollectivePermute: { + case HloOpcode::kCollectivePermute: + case HloOpcode::kCollectivePermuteStart: { std::vector> source_target_pairs( proto.source_target_pairs_size()); absl::optional channel_id; @@ -445,8 +464,17 @@ StatusOr> HloInstruction::CreateFromProto( source_target_pairs[i].first = proto.source_target_pairs(i).source(); source_target_pairs[i].second = proto.source_target_pairs(i).target(); } - instruction = CreateCollectivePermute(shape, operands(0), - source_target_pairs, channel_id); + + if (opcode == HloOpcode::kCollectivePermute) { + instruction = CreateCollectivePermute(shape, operands(0), + source_target_pairs, channel_id); + } else if (opcode == HloOpcode::kCollectivePermuteStart) { + instruction = CreateCollectivePermuteStart( + shape, operands(0), source_target_pairs, channel_id); + } else { + LOG(FATAL) << "Expect CollectivePermute or CollectivePermuteStart, " + << "but got " << HloOpcodeString(opcode); + } break; } case HloOpcode::kReplicaId: { @@ -787,6 +815,7 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state, case HloOpcode::kRoundNearestAfz: case HloOpcode::kBitcast: case HloOpcode::kCeil: + case HloOpcode::kCollectivePermuteDone: case HloOpcode::kCopy: case HloOpcode::kCopyStart: case HloOpcode::kCopyDone: @@ -929,6 +958,15 @@ HloInstruction::CreateReducePrecision(const Shape& shape, shape, operand, exponent_bits, mantissa_bits); } +/* static */ std::unique_ptr HloInstruction::CreateAllGather( + const Shape& shape, HloInstruction* operand, int64 all_gather_dimension, + const std::vector& replica_groups, bool constrain_layout, + const absl::optional& channel_id, bool use_global_device_ids) { + return absl::make_unique( + shape, operand, all_gather_dimension, replica_groups, constrain_layout, + channel_id, use_global_device_ids); +} + /* static */ std::unique_ptr HloInstruction::CreateAllReduce( const Shape& shape, absl::Span operands, HloComputation* reduce_computation, @@ -955,7 +993,18 @@ HloInstruction::CreateCollectivePermute( const std::vector>& source_target_pairs, const absl::optional& channel_id) { return absl::make_unique( - shape, operand, source_target_pairs, channel_id); + HloOpcode::kCollectivePermute, shape, operand, source_target_pairs, + channel_id); +} + +/* static 
*/ std::unique_ptr +HloInstruction::CreateCollectivePermuteStart( + const Shape& shape, HloInstruction* operand, + const std::vector>& source_target_pairs, + const absl::optional& channel_id) { + return absl::make_unique( + HloOpcode::kCollectivePermuteStart, shape, operand, source_target_pairs, + channel_id); } /* static */ std::unique_ptr HloInstruction::CreateReplicaId() { @@ -1518,9 +1567,11 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kParameter: case HloOpcode::kGetTupleElement: case HloOpcode::kReducePrecision: + case HloOpcode::kAllGather: case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: + case HloOpcode::kCollectivePermuteStart: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kConvolution: @@ -1547,6 +1598,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kBitcast: case HloOpcode::kCeil: case HloOpcode::kClz: + case HloOpcode::kCollectivePermuteDone: case HloOpcode::kCopy: case HloOpcode::kCopyStart: case HloOpcode::kCopyDone: @@ -1900,6 +1952,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kCeil: case HloOpcode::kClamp: case HloOpcode::kClz: + case HloOpcode::kCollectivePermuteDone: case HloOpcode::kComplex: case HloOpcode::kConvert: case HloOpcode::kCopy: @@ -1997,9 +2050,11 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kReducePrecision: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: + case HloOpcode::kAllGather: case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: + case HloOpcode::kCollectivePermuteStart: case HloOpcode::kConvolution: case HloOpcode::kCustomCall: case HloOpcode::kReduceWindow: @@ -2851,12 +2906,18 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleConvolution(this); case HloOpcode::kFft: return visitor->HandleFft(this); + case HloOpcode::kAllGather: + return visitor->HandleAllGather(this); case HloOpcode::kAllReduce: return visitor->HandleAllReduce(this); case HloOpcode::kAllToAll: return visitor->HandleAllToAll(this); case HloOpcode::kCollectivePermute: return visitor->HandleCollectivePermute(this); + case HloOpcode::kCollectivePermuteStart: + return visitor->HandleCollectivePermuteStart(this); + case HloOpcode::kCollectivePermuteDone: + return visitor->HandleCollectivePermuteDone(this); case HloOpcode::kReplicaId: return visitor->HandleReplicaId(this); case HloOpcode::kPartitionId: @@ -3934,6 +3995,10 @@ const PaddingConfig& HloInstruction::padding_config() const { return Cast(this)->padding_config(); } +PaddingConfig* HloInstruction::mutable_padding_config() { + return Cast(this)->mutable_padding_config(); +} + int64 HloInstruction::slice_sizes(int64 dimension) const { return Cast(this)->slice_sizes(dimension); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 923138862a7..7a5d506b681 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -618,6 +618,16 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, const int exponent_bits, const int mantissa_bits); + // Creates an all-gather op, which concats the operands of all participants + // along all_gather_dimension. 
The replica_groups, channel_id, and + // use_global_device_ids arguments are identical to those in all-reduce, + // except that the order of the group members determines the concatenation + // order of inputs from different participants. + static std::unique_ptr CreateAllGather( + const Shape& shape, HloInstruction* operand, int64 all_gather_dimension, + const std::vector& replica_groups, bool constrain_layout, + const absl::optional& channel_id, bool use_global_device_ids); + // Creates a cross replica reduction op. // // `reduction_computation`: the reduction function. @@ -671,7 +681,7 @@ class HloInstruction { const absl::optional& channel_id, const absl::optional& split_dimension = absl::nullopt); - // Creates a communication instructions that permutes data cross replicas. + // Creates a communication instruction that permutes data cross replicas. // Data is sent/received according to the (source_replica_id, // target_replica_id) pairs in `source_target_pairs`. If a replica id is not a // target_replica_id in any pair, the output on that replica is a tensor @@ -681,6 +691,13 @@ class HloInstruction { const std::vector>& source_target_pairs, const absl::optional& channel_id); + // Creates a communication instruction that initiates the start of + // CollectivePermute. + static std::unique_ptr CreateCollectivePermuteStart( + const Shape& shape, HloInstruction* operand, + const std::vector>& source_target_pairs, + const absl::optional& channel_id); + // Creates an instruction that returns a U32 replica ID. static std::unique_ptr CreateReplicaId(); @@ -1800,6 +1817,7 @@ class HloInstruction { // Delegates to HloPadInstruction::padding_config. const PaddingConfig& padding_config() const; + PaddingConfig* mutable_padding_config(); // Delegates to HloDynamicSliceInstruction::slice_sizes. 
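A hypothetical builder-side sketch of the new CreateAllGather factory, not part of the patch. The element types of replica_groups and channel_id are stripped in this excerpt; they are assumed to be ReplicaGroup and int64, matching the other collective creation APIs. The operand is assumed to be f32[128,32] with four participants, mirroring the parser test added later in this change.

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Adds an all-gather that concatenates a f32[128,32] operand from 4
// participants along dimension 1, producing f32[128,128].
xla::HloInstruction* AddAllGather(xla::HloComputation::Builder* builder,
                                  xla::HloInstruction* operand) {
  xla::Shape result_shape = xla::ShapeUtil::MakeShape(xla::F32, {128, 128});
  return builder->AddInstruction(xla::HloInstruction::CreateAllGather(
      result_shape, operand, /*all_gather_dimension=*/1,
      /*replica_groups=*/{}, /*constrain_layout=*/false,
      /*channel_id=*/absl::nullopt, /*use_global_device_ids=*/false));
}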
int64 slice_sizes(int64 dimension) const; diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index eb821d40e78..9c5a66f0040 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -556,6 +556,51 @@ bool HloCollectiveInstruction::IdenticalSlowPath( }); } +HloAllGatherInstruction::HloAllGatherInstruction( + const Shape& shape, HloInstruction* operand, int64 all_gather_dimension, + const std::vector& replica_groups, bool constrain_layout, + const absl::optional& channel_id, bool use_global_device_ids) + : HloCollectiveInstruction(HloOpcode::kAllGather, shape, {operand}, + replica_groups, constrain_layout, channel_id), + all_gather_dimension_(all_gather_dimension), + use_global_device_ids_(use_global_device_ids) {} + +std::vector HloAllGatherInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector result = + HloCollectiveInstruction::ExtraAttributesToStringImpl(options); + result.push_back(StrCat("dimensions={", all_gather_dimension_, "}")); + if (use_global_device_ids_) { + result.push_back("use_global_device_ids=true"); + } + return result; +} + +std::unique_ptr +HloAllGatherInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* /*context*/) const { + return absl::make_unique( + shape, new_operands[0], all_gather_dimension(), replica_groups(), + constrain_layout(), channel_id(), use_global_device_ids()); +} + +HloInstructionProto HloAllGatherInstruction::ToProto() const { + HloInstructionProto proto = HloCollectiveInstruction::ToProto(); + proto.add_dimensions(all_gather_dimension_); + return proto; +} + +bool HloAllGatherInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) && + all_gather_dimension_ == casted_other.all_gather_dimension() && + use_global_device_ids() == casted_other.use_global_device_ids(); +} + HloAllReduceInstruction::HloAllReduceInstruction( const Shape& shape, absl::Span operands, HloComputation* reduce_computation, @@ -658,10 +703,10 @@ bool HloAllToAllInstruction::IdenticalSlowPath( } HloCollectivePermuteInstruction::HloCollectivePermuteInstruction( - const Shape& shape, HloInstruction* operand, + HloOpcode opcode, const Shape& shape, HloInstruction* operand, const std::vector>& source_target_pairs, const absl::optional& channel_id) - : HloChannelInstruction(HloOpcode::kCollectivePermute, shape, channel_id), + : HloChannelInstruction(opcode, shape, channel_id), source_target_pairs_(source_target_pairs) { AppendOperand(operand); } @@ -693,6 +738,9 @@ bool HloCollectivePermuteInstruction::IdenticalSlowPath( const HloInstruction& other, const std::function& eq_computations) const { + if (opcode() != other.opcode()) { + return false; + } const auto& casted_other = static_cast(other); return HloChannelInstruction::IdenticalSlowPath(other, eq_computations) && @@ -707,7 +755,7 @@ HloCollectivePermuteInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* /*context*/) const { return absl::make_unique( - shape, new_operands[0], source_target_pairs(), channel_id()); + opcode(), shape, new_operands[0], source_target_pairs(), channel_id()); } HloReverseInstruction::HloReverseInstruction(const Shape& shape, @@ -1819,8 +1867,14 @@ 
std::unique_ptr HloParameterInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - return absl::make_unique(parameter_number_, shape, - name()); + auto clone = absl::make_unique(parameter_number_, + shape, name()); + if (parameter_replicated_at_leaf_buffers_ && + ShapeUtil::Equal(shape, this->shape())) { + clone->set_parameter_replicated_at_leaf_buffers( + *parameter_replicated_at_leaf_buffers_); + } + return clone; } HloGetTupleElementInstruction::HloGetTupleElementInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index eecd02d891e..6da01dc088e 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -348,6 +348,38 @@ class HloCollectiveInstruction : public HloChannelInstruction { bool constrain_layout_; }; +class HloAllGatherInstruction : public HloCollectiveInstruction { + public: + explicit HloAllGatherInstruction( + const Shape& shape, HloInstruction* operand, int64 all_gather_dimension, + const std::vector& replica_groups, bool constrain_layout, + const absl::optional& channel_id, bool use_global_device_ids); + // Same as HloAllReduceInstruction::use_global_device_ids. + bool use_global_device_ids() const { return use_global_device_ids_; } + + // The dimension on which data from different participants are concatenated. + int64 all_gather_dimension() const { return all_gather_dimension_; } + + protected: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + HloInstructionProto ToProto() const override; + + private: + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; + + int64 all_gather_dimension_; + bool use_global_device_ids_; +}; + class HloAllReduceInstruction : public HloCollectiveInstruction { public: explicit HloAllReduceInstruction( @@ -431,7 +463,7 @@ class HloAllToAllInstruction : public HloCollectiveInstruction { class HloCollectivePermuteInstruction : public HloChannelInstruction { public: explicit HloCollectivePermuteInstruction( - const Shape& shape, HloInstruction* operand, + HloOpcode opcode, const Shape& shape, HloInstruction* operand, const std::vector>& source_target_pairs, const absl::optional& channel_id); @@ -674,7 +706,7 @@ class HloMapInstruction : public HloInstruction { // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } - std::vector* mutable_dimensions() { return &dimensions_; } + std::vector* mutable_dimensions() override { return &dimensions_; } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -1377,6 +1409,7 @@ class HloPadInstruction : public HloInstruction { const PaddingConfig& padding_config); // Returns the padding configuration for a pad node. const PaddingConfig& padding_config() const { return padding_config_; } + PaddingConfig* mutable_padding_config() { return &padding_config_; } // Returns the padding value. 
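The new mutable_padding_config accessors (on HloInstruction and HloPadInstruction above) let a pass edit a pad's PaddingConfig in place instead of rebuilding the instruction. A hedged sketch, not part of the patch; `pad` is assumed to be a kPad instruction, and the field names come from the PaddingConfig proto in xla_data.proto.

#include "tensorflow/compiler/xla/service/hlo_instruction.h"

// Grows the high edge padding of dimension 0 by one element, in place.
void WidenPadDim0(xla::HloInstruction* pad) {
  xla::PaddingConfig* config = pad->mutable_padding_config();
  auto* dim0 = config->mutable_dimensions(0);
  dim0->set_edge_padding_high(dim0->edge_padding_high() + 1);
  // The caller is still responsible for updating pad's shape to match.
}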
const HloInstruction* padding_value() const { return operand(1); } HloInstruction* mutable_padding_value() { return mutable_operand(1); } diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index bc1745a0791..5502665e886 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/base/casts.h" #include "absl/strings/ascii.h" #include "absl/strings/escaping.h" #include "absl/strings/numbers.h" @@ -370,6 +371,11 @@ TokKind HloLexer::LexNumberOrPattern() { if (absl::SimpleAtoi(slice, &token_state_.int64_val)) { return TokKind::kInt; } + uint64 uint64_val; + if (absl::SimpleAtoi(slice, &uint64_val)) { + token_state_.int64_val = absl::bit_cast(uint64_val); + return TokKind::kInt; + } LOG(ERROR) << "Failed to parse int literal: " << slice; return TokKind::kError; } diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index ec048bef9e8..cb1b1d0dae4 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -203,6 +203,7 @@ HLO_MATCHER(Abs); HLO_MATCHER(Add); HLO_MATCHER(AddDependency); HLO_MATCHER(AfterAll); +HLO_MATCHER(AllGather); HLO_MATCHER(AllReduce); HLO_MATCHER(AllToAll); HLO_MATCHER(And); diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index de65ed99303..9722d5c2b76 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -420,6 +420,8 @@ StatusOr HloModule::CreateModuleConfigFromShape( if (execution_options->num_partitions() > 0) { module_config.set_num_partitions(execution_options->num_partitions()); } + module_config.set_use_spmd_partitioning( + execution_options->use_spmd_partitioning()); if (execution_options->has_device_assignment()) { TF_ASSIGN_OR_RETURN(std::unique_ptr device_assignment, DeviceAssignment::Deserialize( diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index b31a9ae6ca5..964f83322a4 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -128,6 +128,11 @@ class HloModuleConfig { } int64 num_partitions() const { return num_partitions_; } + void set_use_spmd_partitioning(bool use_spmd_partitioning) { + use_spmd_partitioning_ = use_spmd_partitioning; + } + bool use_spmd_partitioning() const { return use_spmd_partitioning_; } + // Return a string which unambiguously represents all the fields of this data // structure. Used for generating a cache key for storing the compiled // executable. @@ -199,6 +204,14 @@ class HloModuleConfig { std::vector>* mutable_dot_config() { return &dot_config_; } + const std::vector>>& layout_config() const { + return layout_config_; + } + + std::vector>>* mutable_layout_config() { + return &layout_config_; + } + private: // If you add new members, be sure to update compilation_cache_key. @@ -216,6 +229,10 @@ class HloModuleConfig { // The number of partitions (model parallelism) to compile this binary for. int64 num_partitions_ = 1; + // Whether to use SPMD (true) or MPMD (false) when num_partitions_ > 0 and XLA + // needs to partition the module. + bool use_spmd_partitioning_ = false; + // The target maximum parallelism at which to partition HLOs for parallel // execution on the CPU backend. 
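A minimal sketch, not part of the patch, of the use_spmd_partitioning knob added to HloModuleConfig above; it mirrors what HloModule::CreateModuleConfigFromShape now copies from ExecutionOptions.

// Compile the module as a single SPMD program rather than one program per
// partition (MPMD) when XLA has to partition it across num_partitions()
// devices.
void EnableSpmdPartitioning(xla::HloModuleConfig* config) {
  config->set_use_spmd_partitioning(true);
}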
int64 intra_op_parallelism_threads_ = -1; @@ -232,6 +249,9 @@ class HloModuleConfig { FusionConfigCollection fusion_config_collection_ = FusionConfigCollection::kOff; + // TODO(b/155665133): Consolidate fusion, dot, and layout config into a proto + // similar to backend config. + // Custom fusion configuration, where fusion_config_[c][v] control if node v // in computation c must be fused to all its consumers (true) or not (false). std::vector> fusion_config_; @@ -240,6 +260,10 @@ class HloModuleConfig { // how to convert dot operation v (sorted topologically and by computation) to // convolution. std::vector> dot_config_; + + // Layout configuration, where layout_config_[v][i] controls the layout + // decision i of operation v. + std::vector>> layout_config_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 2d66237de59..92359bcbdac 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -48,6 +48,7 @@ namespace xla { V(kAdd, "add", 2) \ V(kAddDependency, "add-dependency", 2) \ V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \ + V(kAllGather, "all-gather", 1) \ V(kAllReduce, "all-reduce", kHloOpcodeIsVariadic) \ V(kAllToAll, "all-to-all", kHloOpcodeIsVariadic) \ V(kAtan2, "atan2", 2) \ @@ -62,6 +63,8 @@ namespace xla { V(kCholesky, "cholesky", 1) \ V(kClamp, "clamp", 3) \ V(kCollectivePermute, "collective-permute", 1) \ + V(kCollectivePermuteStart, "collective-permute-start", 1) \ + V(kCollectivePermuteDone, "collective-permute-done", 1) \ V(kClz, "count-leading-zeros", 1) \ V(kCompare, "compare", 2) \ V(kComplex, "complex", 2) \ diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index a9c3cacc4c4..d52a60d2555 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -765,6 +765,7 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, case HloOpcode::kBitcast: case HloOpcode::kCeil: case HloOpcode::kClz: + case HloOpcode::kCollectivePermuteDone: case HloOpcode::kCopy: case HloOpcode::kCopyStart: case HloOpcode::kCopyDone: @@ -850,6 +851,35 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, HloInstruction::CreateBitcastConvert(shape, operands[0])); break; } + case HloOpcode::kAllGather: { + optional>> tmp_groups; + optional> replica_group_ids; + optional channel_id; + optional> dimensions; + optional constrain_layout; + optional use_global_device_ids; + attrs["replica_groups"] = {/*required=*/false, + AttrTy::kBracedInt64ListList, &tmp_groups}; + attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id}; + attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, + &dimensions}; + attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool, + &constrain_layout}; + attrs["use_global_device_ids"] = {/*required=*/false, AttrTy::kBool, + &use_global_device_ids}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + std::vector replica_groups; + if (tmp_groups) { + replica_groups = CreateReplicaGroups(*tmp_groups); + } + instruction = builder->AddInstruction(HloInstruction::CreateAllGather( + shape, operands[0], dimensions->at(0), replica_groups, + constrain_layout ? *constrain_layout : false, channel_id, + use_global_device_ids ? 
*use_global_device_ids : false)); + break; + } case HloOpcode::kAllReduce: { optional>> tmp_groups; optional to_apply; @@ -909,7 +939,8 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, split_dimension)); break; } - case HloOpcode::kCollectivePermute: { + case HloOpcode::kCollectivePermute: + case HloOpcode::kCollectivePermuteStart: { optional>> source_targets; attrs["source_target_pairs"] = { /*required=*/true, AttrTy::kBracedInt64ListList, &source_targets}; @@ -928,9 +959,19 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, pairs[i].first = (*source_targets)[i][0]; pairs[i].second = (*source_targets)[i][1]; } - instruction = - builder->AddInstruction(HloInstruction::CreateCollectivePermute( - shape, operands[0], pairs, channel_id)); + if (opcode == HloOpcode::kCollectivePermute) { + instruction = + builder->AddInstruction(HloInstruction::CreateCollectivePermute( + shape, operands[0], pairs, channel_id)); + } else if (opcode == HloOpcode::kCollectivePermuteStart) { + instruction = builder->AddInstruction( + HloInstruction::CreateCollectivePermuteStart(shape, operands[0], + pairs, channel_id)); + } else { + LOG(FATAL) << "Expect opcode to be CollectivePermute or " + "CollectivePermuteStart, but got " + << HloOpcodeString(opcode); + } break; } case HloOpcode::kReplicaId: { @@ -2569,14 +2610,10 @@ bool HloParserImpl::CheckParsedValueIsInRange(LocTy loc, ParsedElemT value) { std::is_same::value)) << "Unimplemented checking for ParsedElemT"; - ParsedElemT upper_bound; - if (sizeof(LiteralNativeT) >= sizeof(ParsedElemT)) { - upper_bound = std::numeric_limits::max(); - } else { - upper_bound = - static_cast(std::numeric_limits::max()); - } - if (value > upper_bound || value < 0) { + const uint64 unsigned_value = value; + const uint64 upper_bound = + static_cast(std::numeric_limits::max()); + if (unsigned_value > upper_bound) { // Value is out of range for LiteralNativeT. 
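The hlo_lexer and CheckParsedValueIsInRange changes above make unsigned 64-bit literals round-trip: decimals that overflow int64 are parsed as uint64 and stored bit-cast into the int64 token value. A standalone illustration of that round trip, not part of the patch; it only needs Abseil:

#include <cstdint>
#include <iostream>

#include "absl/base/casts.h"
#include "absl/strings/numbers.h"

int main() {
  // 2^63 overflows int64 but fits in uint64, so the lexer falls back to an
  // unsigned parse and stores the bits in the int64 token field.
  uint64_t as_unsigned;
  if (!absl::SimpleAtoi("9223372036854775808", &as_unsigned)) return 1;
  int64_t stored = absl::bit_cast<int64_t>(as_unsigned);  // -9223372036854775808
  // Reinterpreting the stored bits as the literal's u64 type recovers the
  // original value, which is why the updated hlo_parser tests expect success.
  std::cout << stored << " -> " << absl::bit_cast<uint64_t>(stored) << "\n";
  // u64[] constant(-1) parses the same way: -1 reinterprets to 2^64 - 1.
  std::cout << absl::bit_cast<uint64_t>(int64_t{-1}) << "\n";
  return 0;
}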
return Error(loc, StrCat("value ", value, " is out of range for literal's primitive type ", diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 7e66b4e648d..a687d0e1921 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -1480,6 +1480,43 @@ ENTRY CRS { )" }, +// all-gather +{ +"AllGather", +R"(HloModule AllGather + +ENTRY AllGather { + input = f32[128,32]{0,1} parameter(0) + ROOT ag = f32[128,128]{0,1} all-gather(input), replica_groups={}, dimensions={1} +} + +)" +}, +// all-gather with constrained layout +{ +"AllGatherWithLayout", +R"(HloModule AllGather + +ENTRY AllGather { + input = f32[128,32]{0,1} parameter(0) + ROOT ag = f32[128,128]{0,1} all-gather(input), replica_groups={}, constrain_layout=true, dimensions={1} +} + +)" +}, +// all-gather with subgroups +{ +"AllGatherWithSubgroups", +R"(HloModule AllGatherWithSubgroups + +ENTRY AllGatherWithSubgroups { + input = f32[128,32]{0,1} parameter(0) + ROOT ag = f32[128,64]{0,1} all-gather(input), replica_groups={{0,1},{2,3}}, dimensions={1} +} + +)", +/*replica_count=*/4, +}, // all-to-all { "AllToAll", @@ -1516,6 +1553,20 @@ ENTRY CollectivePermute { ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}} } +)", +/*replica_count=*/4 +}, +// collective-permute-start and -done +{ +"CollectivePermuteStartAndDone", +R"(HloModule CollectivePermuteStartAndDone + +ENTRY CollectivePermuteStartAndDone { + input = f32[128,32]{0,1} parameter(0) + collective-permute-start.1 = (f32[128,32]{0,1}, f32[128,32]{0,1}, u32[], u32[]) collective-permute-start(input), source_target_pairs={{0,1},{1,2},{2,3}} + ROOT collective-permute-done.1 = f32[128,32]{0,1} collective-permute-done(collective-permute-start.1) +} + )", /*replica_count=*/4 }, @@ -1963,9 +2014,7 @@ TEST_F(HloParserTest, ConstantUnsignedUnderflow) { ROOT %constant = u64[] constant(-1) })"; auto result = ParseAndReturnUnverifiedModule(original); - EXPECT_NE(Status::OK(), result.status()); - ExpectHasSubstr(result.status().error_message(), - "is out of range for literal's primitive type U64"); + EXPECT_EQ(Status::OK(), result.status()); } TEST_F(HloParserTest, ConstantUnsignedOverflow) { @@ -1987,7 +2036,7 @@ TEST_F(HloParserTest, ConstantUnsignedInt64Overflow) { ROOT %constant = u64[] constant(9223372036854775808) })"; auto result = ParseAndReturnUnverifiedModule(original); - EXPECT_NE(Status::OK(), result.status()); + EXPECT_EQ(Status::OK(), result.status()); } TEST_F(HloParserTest, ConstantC64Overflow) { diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util.cc b/tensorflow/compiler/xla/service/hlo_sharding_util.cc new file mode 100644 index 00000000000..7fc05608800 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_sharding_util.cc @@ -0,0 +1,592 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_sharding_util.h" + +#include + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/xla/array.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace hlo_sharding_util { + +absl::optional SelectDominantDevice( + const std::map& device_map, int64* top_count) { + int64 device = 0; + int64 count = 0; + for (auto& it : device_map) { + if (it.second > count) { + count = it.second; + device = it.first; + } + } + if (top_count != nullptr) { + *top_count = count; + } + return count > 0 ? absl::optional(device) : absl::optional(); +} + +Status AssignComputationDevice(HloComputation* computation, int64 device) { + VLOG(4) << "Assigning device " << device << " to " << computation->name() + << " computation"; + for (HloInstruction* instruction : computation->instructions()) { + if (!instruction->has_sharding()) { + VLOG(4) << "Assigning device " << device << " to " << instruction->name(); + instruction->set_device_sharding(device); + } + } + return Status::OK(); +} + +absl::optional GetMostOccurringDevice( + absl::Span instructions) { + std::map device_map; + for (HloInstruction* instruction : instructions) { + if (instruction->has_sharding()) { + for (auto& it : instruction->sharding().UsedDevices(nullptr)) { + // The UsedDevices() API returns a map. + device_map[it.first] += it.second; + } + } + } + return SelectDominantDevice(device_map, nullptr); +} + +StatusOr> GetDominantDevice( + absl::Span computations, double dominant_factor) { + int64 instruction_count = 0; + std::map device_map; + for (HloComputation* computation : computations) { + for (HloInstruction* instruction : computation->instructions()) { + int64 count = 1; + if (instruction->has_sharding()) { + for (auto& it : instruction->sharding().UsedDevices(&count)) { + // The UsedDevices() API returns a map. 
+ device_map[it.first] += it.second; + } + } + instruction_count += count; + } + } + int64 count; + absl::optional device = SelectDominantDevice(device_map, &count); + absl::optional dominant_device; + if (device) { + double factor = + static_cast(count) / static_cast(instruction_count); + if (factor >= dominant_factor) { + dominant_device = device; + } + } + return dominant_device; +} + +HloSharding TransposeSharding(const HloSharding& sharding, + const std::vector& dimensions) { + if (sharding.IsTileMaximal()) { + return sharding; + } + const int64 rank = dimensions.size(); + std::vector tile_assignment_dim(rank); + for (int64 i = 0; i < rank; ++i) { + tile_assignment_dim[i] = sharding.tile_assignment().dim(dimensions[i]); + } + Array tile_assignment = sharding.tile_assignment(); + tile_assignment.Reshape(tile_assignment_dim); + tile_assignment.Each([&](absl::Span indices, int64* value) { + std::vector src_indices(indices.size(), -1); + for (int64 i = 0; i < indices.size(); ++i) { + src_indices[dimensions[i]] = indices[i]; + } + *value = sharding.tile_assignment()(src_indices); + }); + return HloSharding::Tile(tile_assignment); +} + +absl::optional ReshapeSharding(const Shape& source_shape, + const Shape& target_shape, + const HloSharding& sharding) { + if (sharding.IsTileMaximal()) { + return sharding; + } + + // In case of a tiled sharding the reshaped sharding will be a valid if the + // reshape is composed from the following operations: + // * Adding or removing dimensions with size 1. + // * Merging consecutive dimensions where only the most major is sharded. + // * Splitting a dimension to consecutive dimensions. + // * Any reshaping of unsharded dimensions. + // Note that merge and split can happen consecutively on the same dimension, + // e.g., f32[1024,256,1024] to f32[128,2048,1024] can be considered that 1024 + // gets split into 128 and 8, but 8 then gets merged with 256. We use stacks + // to make supporting such cases easy. + const Shape tile_shape = sharding.TileShape(source_shape); + std::vector target_tile_assignment_dimensions; + std::vector source_dims_stack(source_shape.rank()); + std::vector target_dims_stack(target_shape.rank()); + std::vector sharding_tile_dims_stack(source_shape.rank()); + for (int64 i = 0; i < source_shape.rank(); ++i) { + source_dims_stack[i] = source_shape.dimensions(source_shape.rank() - 1 - i); + sharding_tile_dims_stack[i] = + sharding.tile_assignment().dim(source_shape.rank() - 1 - i); + } + for (int64 i = 0; i < target_shape.rank(); ++i) { + target_dims_stack[i] = target_shape.dimensions(target_shape.rank() - 1 - i); + } + while (!source_dims_stack.empty() || !target_dims_stack.empty()) { + if (target_dims_stack.empty()) { + if (Product(sharding_tile_dims_stack) != 1) { + return absl::nullopt; + } + break; + } + int64 s_size = 1; + int64 t_size = 1; + int64 s_partitions = 1; + if (!source_dims_stack.empty()) { + s_size = source_dims_stack.back(); + source_dims_stack.pop_back(); + s_partitions = sharding_tile_dims_stack.back(); + sharding_tile_dims_stack.pop_back(); + } + t_size = target_dims_stack.back(); + target_dims_stack.pop_back(); + if (s_partitions * Product(sharding_tile_dims_stack) == 1) { + // No more partitions left. + target_tile_assignment_dimensions.push_back(1); + continue; + } + if (s_size == t_size) { + // Same dimension. + target_tile_assignment_dimensions.push_back(s_partitions); + } else if (t_size == 1) { + // Trivial dimension added. 
+ target_tile_assignment_dimensions.push_back(1); + source_dims_stack.push_back(s_size); + sharding_tile_dims_stack.push_back(s_partitions); + } else if (s_size == 1) { + // Trivial dimension removed. + if (s_partitions != 1) { + return absl::nullopt; + } + target_dims_stack.push_back(t_size); + } else if (s_size > t_size) { + // Dimension split. + if (s_size % t_size != 0 || t_size % s_partitions != 0) { + return absl::nullopt; + } + target_tile_assignment_dimensions.push_back(s_partitions); + // We have part of the s_size unprocessed, so put it back to stack. + source_dims_stack.push_back(s_size / t_size); + sharding_tile_dims_stack.push_back(1); + } else { + // Dimension merge. Also merge the source dimension with the next, and + // process it next time. + if (s_size % s_partitions != 0) { + return absl::nullopt; + } + CHECK(!source_dims_stack.empty()); + if (sharding_tile_dims_stack.back() != 1 && s_size != s_partitions) { + // If the next dimension to combine is sharded, we require that the + // current dimension's shard size to be 1. Otherwise, the new shard + // would be non-contiguous. + return absl::nullopt; + } + source_dims_stack.back() *= s_size; + sharding_tile_dims_stack.back() *= s_partitions; + target_dims_stack.push_back(t_size); + } + } + Array new_tile_assignment = sharding.tile_assignment(); + new_tile_assignment.Reshape(target_tile_assignment_dimensions); + return HloSharding::Tile(new_tile_assignment); +} + +HloSharding ReverseSharding(const HloSharding& sharding, + absl::Span dimensions) { + if (sharding.IsTileMaximal() || dimensions.empty()) { + return sharding; + } + + Array new_tile_assignment(sharding.tile_assignment().dimensions()); + new_tile_assignment.Each([&](absl::Span indices, int64* device) { + std::vector original_indices(indices.begin(), indices.end()); + for (int64 d : dimensions) { + original_indices[d] = + new_tile_assignment.dim(d) - 1 - original_indices[d]; + } + *device = sharding.tile_assignment()(original_indices); + }); + return HloSharding::Tile(new_tile_assignment); +} + +HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim, + absl::Span dims) { + CHECK(!sharding.IsTuple() && !sharding.IsTileMaximal()); + CHECK_NE(absl::c_find(dims, dim), dims.end()) << "dim is not in dims"; + // We optimize the tile assignment on the single dimension dim in a way to + // minimize communication among devices caused by the reshard: + // +---+---+ +---+---+ +-+-+-+-+ + // | | | | 0 | | | | | | + // | 0 | 1 | +-------+ | | | | | + // | | | reshape on | 1 | reshape on | | | | | + // +---+---+ dim 0 => +-------+ dim 1 => |0|2|1|3| + // | | | | 2 | | | | | | + // | 2 | 3 | +-------+ | | | | | + // | | | | 3 | | | | | | + // +---+---+ +---+---+ +-+-+-+-+ + + std::vector tile_dims(sharding.tile_assignment().num_dimensions(), 1); + // Handle ignore dimensions. 
+ std::vector ignore_sizes; + int64 ignore_size = 1; + for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) { + if (absl::c_find(dims, i) == dims.end()) { + int64 size = sharding.tile_assignment().dim(i); + ignore_sizes.push_back(size); + tile_dims[i] = size; + ignore_size *= size; + } + } + + using Buckets = std::vector>; + Array buckets(ignore_sizes, + Buckets(sharding.tile_assignment().dim(dim))); + sharding.tile_assignment().Each( + [&](absl::Span index, int64 device) { + std::vector ignore_index; + for (int64 i = 0; i < index.size(); ++i) { + if (absl::c_find(dims, i) == dims.end()) { + ignore_index.push_back(index[i]); + } + } + buckets(ignore_index)[index[dim]].push_back(device); + }); + std::vector devices; + buckets.Each([&](absl::Span index, const Buckets& buckets) { + for (auto& bucket : buckets) { + devices.insert(devices.end(), bucket.begin(), bucket.end()); + } + }); + tile_dims[dim] = devices.size() / ignore_size; + Array tile_assignment(tile_dims); + tile_assignment.SetValues(devices); + return HloSharding::Tile(tile_assignment); +} + +bool ContainsTileSharding(const HloModule& module) { + for (const HloComputation* computation : module.computations()) { + for (const HloInstruction* instruction : computation->instructions()) { + if (instruction->has_sharding() && + !instruction->sharding().IsTileMaximal()) { + return true; + } + } + } + return false; +} + +HloSharding GatherOutputSharding(const HloSharding& index_sharding, + const HloInstruction* hlo) { + if (index_sharding.IsTileMaximal()) { + return index_sharding; + } + + const GatherDimensionNumbers& dnums = hlo->gather_dimension_numbers(); + std::vector output_tile_assignment_dims; + for (int64 i = 0, index_dim = 0; i < hlo->shape().rank(); ++i) { + if (absl::c_binary_search(dnums.offset_dims(), i)) { + output_tile_assignment_dims.push_back(1); + } else { + output_tile_assignment_dims.push_back( + index_sharding.tile_assignment().dim(index_dim)); + index_dim++; + } + } + Array new_tile_assignment = index_sharding.tile_assignment(); + new_tile_assignment.Reshape(output_tile_assignment_dims); + return HloSharding::Tile(new_tile_assignment); +} + +HloSharding GatherIndexSharding(const HloSharding& output_sharding, + const HloInstruction* hlo) { + if (output_sharding.IsTileMaximal()) { + return output_sharding; + } + + const GatherDimensionNumbers& dnums = hlo->gather_dimension_numbers(); + std::vector index_tile_assignment_dims; + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + if (!absl::c_binary_search(dnums.offset_dims(), i)) { + index_tile_assignment_dims.push_back( + output_sharding.tile_assignment().dim(i)); + } + } + Array new_tile_assignment = output_sharding.tile_assignment(); + new_tile_assignment.Reshape(index_tile_assignment_dims); + return HloSharding::Tile(new_tile_assignment); +} + +HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo) { + if (hlo.sharding().IsTileMaximal()) { + return hlo.sharding(); + } + + const GatherDimensionNumbers& dnums = hlo.gather_dimension_numbers(); + std::vector tile_assignment_dims(hlo.shape().rank()); + int64 num_elements = 1; + for (int64 i = 0; i < hlo.shape().rank(); ++i) { + if (!absl::c_binary_search(dnums.offset_dims(), i)) { + tile_assignment_dims[i] = hlo.sharding().tile_assignment().dim(i); + num_elements *= hlo.sharding().tile_assignment().dim(i); + } else { + tile_assignment_dims[i] = 1; + } + } + if (num_elements == hlo.sharding().tile_assignment().num_elements()) { + // Output sharding is only on non offset dimensions. 
We use output sharding + // to shard this gather op directly. + return hlo.sharding(); + } + + if (num_elements == 1) { + // Output sharding is only on offset dimensions. We do not shard this gather + // op. Return a tile maximal sharding with the first device in output + // sharding tile assignment. + return HloSharding::AssignDevice(*hlo.sharding().tile_assignment().begin()); + } + + // Output sharding is on both offset and non offset dimensions. We shard the + // gather op only on non offset dimensions. + // For example: + // - the gather op has sharding [2,2]{0,1,2,3}, + // - first dimension is non offset dimension, + // - second dimension is offset dimension, + // Then the result sharding will be [2,1]{0,2}. + std::vector slice_starts(hlo.shape().rank(), 0LL), + slice_limits(hlo.shape().rank()); + for (int64 i = 0; i < hlo.shape().rank(); ++i) { + if (!absl::c_binary_search(dnums.offset_dims(), i)) { + slice_limits[i] = hlo.sharding().tile_assignment().dim(i); + } else { + slice_limits[i] = 1; + } + } + Array tile_assignment = + hlo.sharding().tile_assignment().Slice(slice_starts, slice_limits); + return HloSharding::Tile(tile_assignment); +} + +HloSharding ScatterIndexSharding(const HloSharding& data_sharding, + const HloInstruction* hlo) { + if (data_sharding.IsTileMaximal()) { + return data_sharding; + } + + const ScatterDimensionNumbers& dnums = hlo->scatter_dimension_numbers(); + std::vector index_tile_assignment_dims; + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + if (!absl::c_binary_search(dnums.update_window_dims(), i)) { + index_tile_assignment_dims.push_back( + data_sharding.tile_assignment().dim(i)); + } + } + if (index_tile_assignment_dims.size() < hlo->operand(1)->shape().rank()) { + index_tile_assignment_dims.push_back(1); + } + Array new_tile_assignment = data_sharding.tile_assignment(); + new_tile_assignment.Reshape(index_tile_assignment_dims); + return HloSharding::Tile(new_tile_assignment); +} + +HloSharding ScatterDataSharding(const HloSharding& index_sharding, + const HloInstruction* hlo) { + if (index_sharding.IsTileMaximal()) { + return index_sharding; + } + + const ScatterDimensionNumbers& dnums = hlo->scatter_dimension_numbers(); + std::vector data_tile_assignment_dims; + for (int64 i = 0, index_dim = 0; i < hlo->shape().rank(); ++i) { + if (absl::c_binary_search(dnums.update_window_dims(), i)) { + data_tile_assignment_dims.push_back(1); + } else { + data_tile_assignment_dims.push_back( + index_sharding.tile_assignment().dim(index_dim)); + index_dim++; + } + } + Array new_tile_assignment = index_sharding.tile_assignment(); + new_tile_assignment.Reshape(data_tile_assignment_dims); + return HloSharding::Tile(new_tile_assignment); +} + +HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding, + const HloInstruction& hlo) { + if (index_sharding.IsTileMaximal()) { + return index_sharding; + } + + // Only shard on first "number of scatter_window_dims" dimensions. + const ScatterDimensionNumbers& dnums = hlo.scatter_dimension_numbers(); + int64 num_elements = 1; + int64 index_dim = 0; + for (int64 i = 0; i < hlo.shape().rank(); ++i) { + if (absl::c_binary_search(dnums.inserted_window_dims(), i)) { + num_elements *= index_sharding.tile_assignment().dim(index_dim); + index_dim++; + } + } + if (num_elements == index_sharding.tile_assignment().num_elements()) { + // Index sharding is only on scatter_window_dims. We use this index sharding + // directly. + return index_sharding; + } + + // Index sharding is only on update_window_dims. 
We do not shard this scatter + // op. Return a tile maximal sharding with the first device in index sharding + // tile assignment. + if (num_elements == 1) { + return HloSharding::AssignDevice(*index_sharding.tile_assignment().begin()); + } + + const int64 index_rank = hlo.operand(1)->shape().rank(); + std::vector slice_starts(index_rank, 0LL), slice_limits(index_rank); + for (int64 i = 0; i < index_rank; ++i) { + if (i < index_dim) { + slice_limits[i] = index_sharding.tile_assignment().dim(i); + } else { + slice_limits[i] = 1; + } + } + Array tile_assignment = + index_sharding.tile_assignment().Slice(slice_starts, slice_limits); + return HloSharding::Tile(tile_assignment); +} + +HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding, + const HloInstruction& hlo) { + if (data_sharding.IsTileMaximal()) { + return data_sharding; + } + + const ScatterDimensionNumbers& dnums = hlo.scatter_dimension_numbers(); + const int64 data_rank = hlo.operand(2)->shape().rank(); + std::vector tile_assignment_dims(data_rank, 1LL); + int64 num_elements = 1; + for (int64 i = 0; i < hlo.shape().rank(); ++i) { + if (absl::c_binary_search(dnums.inserted_window_dims(), i)) { + CHECK_LT(i, data_rank); + tile_assignment_dims[i] = data_sharding.tile_assignment().dim(i); + num_elements *= data_sharding.tile_assignment().dim(i); + } + } + if (num_elements == data_sharding.tile_assignment().num_elements()) { + // Data sharding is only on scatter_window_dims. We use this data sharding + // directly. + return data_sharding; + } + + if (num_elements == 1) { + // Data sharding is only on update_window_dims. We do not shard this + // scatter op. Return a tile maximal sharding with the first device in + // data sharding tile assignment. + return HloSharding::AssignDevice(*data_sharding.tile_assignment().begin()); + } + + // Data sharding is on both update_window_dims and scatter_window_dims. We + // shard the scatter op only on scatter_window_dims. For example: + // - the scatter data has sharding [2,2]{0,1,2,3}, + // - first dimension is scatter_window_dims, + // - second dimension is update_window_dims, + // Then the result sharding will be [2,1]{0,2}. + std::vector slice_starts(data_rank, 0LL); + Array tile_assignment = + data_sharding.tile_assignment().Slice(slice_starts, tile_assignment_dims); + return HloSharding::Tile(tile_assignment); +} + +StatusOr, HloOpcode>> +IdentityValueAndHloOpcodeForScatterReduceComputation( + const HloScatterInstruction& scatter) { + auto computation = scatter.to_apply(); + // We only handle computations with 2 parameters and only 1 calculation. 
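// Descriptive note, not part of the patch: a reduce computation with two
// parameters and a single reducing root consists of exactly three
// instructions (parameter 0, parameter 1, and the root), which is the
// magic number behind the instruction_count() != 3 early-out below.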
+ if (computation->instruction_count() != 3) { + return Status( + tensorflow::error::Code::INVALID_ARGUMENT, + "Expected scatter reduce computation with 2 parameters and only 1 " + "calculation"); + } + + auto root_instruction = computation->root_instruction(); + if (root_instruction->opcode() == HloOpcode::kAdd || + root_instruction->opcode() == HloOpcode::kOr) { + return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::Zero( + scatter.shape().element_type())), + root_instruction->opcode()); + } else if (root_instruction->opcode() == HloOpcode::kMultiply || + root_instruction->opcode() == HloOpcode::kAnd) { + return std::make_pair(HloInstruction::CreateConstant( + LiteralUtil::One(scatter.shape().element_type())), + root_instruction->opcode()); + } else if (root_instruction->opcode() == HloOpcode::kMaximum) { + return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::MinValue( + scatter.shape().element_type())), + root_instruction->opcode()); + } else if (root_instruction->opcode() == HloOpcode::kMinimum) { + return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::MaxValue( + scatter.shape().element_type())), + root_instruction->opcode()); + } + + return Status(tensorflow::error::Code::INVALID_ARGUMENT, + "Expected scatter reduce computation which is " + "add/or/multiply/add/min/max"); +} + +std::vector DevicesForSharding( + const HloSharding& sharding, const std::vector& available_devices) { + std::vector devices; + if (sharding.IsReplicated()) { + for (int64 d : available_devices) { + if (!HloSharding::IsReservedDevice(d)) { + devices.push_back(d); + } + } + return devices; + } + + for (int64 i : available_devices) { + if (sharding.UsesDevice(i)) { + devices.push_back(i); + } + } + DCHECK(std::all_of(sharding.tile_assignment().begin(), + sharding.tile_assignment().end(), [&](int64 device) { + return std::find(available_devices.begin(), + available_devices.end(), + device) != available_devices.end(); + })); + return devices; +} + +} // namespace hlo_sharding_util +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util.h b/tensorflow/compiler/xla/service/hlo_sharding_util.h new file mode 100644 index 00000000000..562f6d1420d --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_sharding_util.h @@ -0,0 +1,149 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" + +namespace xla { +namespace hlo_sharding_util { + +// Given a map, selects the device with higher +// occurrence count (if any). If top_count in not nullptr, it will receive the +// count of the dominant device returned. +absl::optional SelectDominantDevice( + const std::map& device_map, int64* top_count); + +// Assigns all the instructions of a computation, to a given device. +// This API does not recurse into called computations, and does not assign +// instructions which already have sharding. +Status AssignComputationDevice(HloComputation* computation, int64 device); + +// Given an instruction container, returns the device which is most commonly +// occurring among the instructions. +absl::optional GetMostOccurringDevice( + absl::Span instructions); + +// Given a set of computations, tries to extract the dominant device. A device +// is dominant if the combined occurrence among all the instructions of the +// input computations, is greater/equal than/to dominant_factor (real number +// from 0 to 1). +// This API does not recurse into called computations. +// If no device exists that satisfies the condition, the returned optional will +// hold no value. +StatusOr> GetDominantDevice( + absl::Span computations, double dominant_factor); + +// Returns the HloSharding with the tile dimensions and tile assignment +// transposed based on the specified dimension numbers. In case of a tile +// maximal sharding returns the original sharding. +HloSharding TransposeSharding(const HloSharding& sharding, + const std::vector& dimensions); + +// Returns the HloSharding with the tile shape reshaped based on the source and +// target shapes and the tile assignment adjusted to correspond to the new tile +// shape or absl::nullopt if the resulting reshape would create an invalid +// sharding (non continuous or non uniformly sized tiles). In case of a tile +// maximal sharding returns the original sharding. +absl::optional ReshapeSharding(const Shape& source_shape, + const Shape& target_shape, + const HloSharding& sharding); + +// Returns the HloSharding with the tile dimensions and tile assignment +// reversed based on the specified dimension numbers. In case of a tile +// maximal sharding returns the original sharding. +HloSharding ReverseSharding(const HloSharding& sharding, + absl::Span dimensions); + +// Returns a sharding tiled on unique dimension dim by reshaping the tile +// assignment of the sharding argument. Only dimensions in the dims span +// argument are considered for reshaping, the others are ignored. +// Assumptions: sharding is tile sharded, and dim must be included in dims. +HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim, + absl::Span dims); + +// Returns true if the provided module includes one or more instructions with +// a tile sharding. +bool ContainsTileSharding(const HloModule& module); + +// Returns the preferred output sharding for a gather op based on the sharding +// of the indces. 
+HloSharding GatherOutputSharding(const HloSharding& index_sharding, + const HloInstruction* hlo); + +// Returns the preferred index sharding for a gather op based on the sharding +// of the output. +HloSharding GatherIndexSharding(const HloSharding& output_sharding, + const HloInstruction* hlo); + +// Returns a new HloSharding for a gather op so that only non offset dimensions +// are sharded. Assume "result" is returned by this function. It is ensured that +// "GetIndexSharding(result, hlo)" will have the same number of elements as +// "result". +HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo); + +// Returns the preferred index sharding for a scatter op based on the sharding +// of the data. +HloSharding ScatterIndexSharding(const HloSharding& data_sharding, + const HloInstruction* hlo); + +// Returns the preferred data sharding for a scatter op based on the sharding +// of the index. +HloSharding ScatterDataSharding(const HloSharding& index_sharding, + const HloInstruction* hlo); + +// Returns a new index sharding for a scatter op so that we only shard on first +// "number of scatter_window_dims" dimensions. Assume "result" is returned by +// this function. It is ensured that "ScatterDataSharding(result, hlo)" will +// have the same number of elements as "result". +HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding, + const HloInstruction& hlo); + +// Returns a new data sharding for a scatter op so that we only shard on +// scatter_window_dims. Assume "result" is returned by this function. It is +// ensured that "ScatterIndexSharding(result, hlo)" will have the same number of +// elements as "result". +HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding, + const HloInstruction& hlo); + +// Returns an identity value and an HloOpcode for reduce computation of scatter +// instruction. +// - If computation is add/or, return 0/false with corresponding op code; +// - If computation is multiply/and, return 1/true with corresponding op code. +// - If computation is min/max, return max value/min value with corresponding op +// code. +// - Otherwise, return error status. +StatusOr, HloOpcode>> +IdentityValueAndHloOpcodeForScatterReduceComputation( + const HloScatterInstruction& scatter); + +// Given a sharding and a list of devices in the topology, return a +// list of the devices that `sharding` applies to. +std::vector DevicesForSharding( + const HloSharding& sharding, const std::vector& available_devices); + +} // namespace hlo_sharding_util +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_util_test.cc new file mode 100644 index 00000000000..02496c75965 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_sharding_util_test.cc @@ -0,0 +1,206 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_sharding_util.h" + +#include "tensorflow/compiler/xla/test.h" + +namespace xla { +namespace hlo_sharding_util { +namespace { + +TEST(HloShardingUtilTest, TransposeShardingReplicated) { + EXPECT_EQ(TransposeSharding(HloSharding::Replicate(), {0, 1, 2}), + HloSharding::Replicate()); +} + +TEST(HloShardingUtilTest, TransposeShardingTiled) { + HloSharding input = HloSharding::Tile(Array4D({{{{0, 1}}, {{2, 3}}}})); + HloSharding output = + HloSharding::Tile(Array4D({{{{0}, {2}}}, {{{1}, {3}}}})); + EXPECT_EQ(TransposeSharding(input, {3, 0, 1, 2}), output); +} + +TEST(HloShardingUtilTest, ReshapeShardingMaximal) { + Shape input_shape = ShapeUtil::MakeShape(F32, {2, 3, 5}); + Shape output_shape = ShapeUtil::MakeShape(F32, {3, 5, 2}); + HloSharding sharding = HloSharding::AssignDevice(7); + absl::optional result = + ReshapeSharding(input_shape, output_shape, sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), sharding); +} + +TEST(HloShardingUtilTest, ReshapeShardingTiledInvalid) { + Shape input_shape = ShapeUtil::MakeShape(F32, {2, 3, 5}); + Shape output_shape = ShapeUtil::MakeShape(F32, {3, 5, 2}); + HloSharding sharding = HloSharding::Tile(Array3D({{{0}, {1}}})); + absl::optional result = + ReshapeSharding(input_shape, output_shape, sharding); + EXPECT_FALSE(result.has_value()); +} + +TEST(HloShardingUtilTest, ReshapeShardingTiledMerge) { + Shape input_shape = ShapeUtil::MakeShape(F32, {4, 5, 7}); + Shape output_shape = ShapeUtil::MakeShape(F32, {20, 7}); + HloSharding input_sharding = + HloSharding::Tile(Array3D({{{0}}, {{1}}})); + HloSharding output_sharding = HloSharding::Tile(Array2D({{0}, {1}})); + absl::optional result = + ReshapeSharding(input_shape, output_shape, input_sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), output_sharding); +} + +TEST(HloShardingUtilTest, ReshapeShardingTiledSplit) { + Shape input_shape = ShapeUtil::MakeShape(F32, {16, 7}); + Shape output_shape = ShapeUtil::MakeShape(F32, {4, 4, 7}); + HloSharding input_sharding = HloSharding::Tile(Array2D({{0}, {1}})); + HloSharding output_sharding = + HloSharding::Tile(Array3D({{{0}}, {{1}}})); + absl::optional result = + ReshapeSharding(input_shape, output_shape, input_sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), output_sharding); +} + +TEST(HloShardingUtilTest, ReshapeShardingTiledSplitThenMerge) { + Shape input_shape = ShapeUtil::MakeShape(F32, {16, 4, 7}); + Shape output_shape = ShapeUtil::MakeShape(F32, {4, 16, 7}); + HloSharding input_sharding = + HloSharding::Tile(Array3D({{{0}}, {{1}}})); + HloSharding output_sharding = + HloSharding::Tile(Array3D({{{0}}, {{1}}})); + absl::optional result = + ReshapeSharding(input_shape, output_shape, input_sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), output_sharding); +} + +TEST(HloShardingUtilTest, ReshapeShardingTiledArbitraryMinorDimensions) { + Shape input_shape = ShapeUtil::MakeShape(F32, {16, 7, 5, 3}); + Shape output_shape = ShapeUtil::MakeShape(F32, {4, 15, 2, 14}); + Array sharding_array({2, 1, 1, 1}); + sharding_array(0, 0, 0, 0) = 0; + sharding_array(1, 0, 0, 0) = 1; + HloSharding sharding = HloSharding::Tile(sharding_array); + absl::optional result = + ReshapeSharding(input_shape, output_shape, sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), sharding); +} + +TEST(HloShardingUtilTest, 
ReshapeShardingTiledTrivialDimensions) { + Shape input_shape = ShapeUtil::MakeShape(F32, {3, 1, 5, 7}); + Shape output_shape = ShapeUtil::MakeShape(F32, {3, 5, 1, 7}); + HloSharding input_sharding = + HloSharding::Tile(Array4D({{{{0}, {1}}}})); + HloSharding output_sharding = + HloSharding::Tile(Array4D({{{{0}}, {{1}}}})); + absl::optional result = + ReshapeSharding(input_shape, output_shape, input_sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), output_sharding); +} + +TEST(HloShardingUtilTest, ReshapeShardingTrivialDImensionInsertedToEnd) { + Shape input_shape = ShapeUtil::MakeShape(F32, {8, 16}); + Shape output_shape = ShapeUtil::MakeShape(F32, {8, 16, 1}); + HloSharding input_sharding = HloSharding::Tile(Array2D({{0}, {1}})); + HloSharding output_sharding = + HloSharding::Tile(Array3D({{{0}}, {{1}}})); + absl::optional result = + ReshapeSharding(input_shape, output_shape, input_sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), output_sharding); +} + +TEST(HloShardingUtilTest, NoopReshapeShardingEmptyTile) { + Shape shape = ShapeUtil::MakeShape(F32, {7, 1, 1}); + HloSharding sharding = HloSharding::Tile(Array3D({{{0}, {1}}})); + absl::optional result = ReshapeSharding(shape, shape, sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), sharding); +} + +TEST(HloShardingUtilTest, ReshapeShardingScalar) { + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1}); + Shape output_shape = ShapeUtil::MakeShape(F32, {}); + HloSharding sharding = HloSharding::Tile(Array3D({{{0}, {1}}})); + absl::optional result = + ReshapeSharding(input_shape, output_shape, sharding); + EXPECT_FALSE(result.has_value()); +} + +TEST(HloShardingUtilTest, ReshapeToTileDimension2D_Dim0) { + HloSharding sharding = HloSharding::Tile(Array2D({{0, 1}, {2, 3}})); + HloSharding result = + ReshapeToTileDimension(sharding, /*dim=*/0, /*dims=*/{0, 1}); + EXPECT_EQ(result.tile_assignment(), Array2D({{0}, {1}, {2}, {3}})); +} + +TEST(HloShardingUtilTest, ReshapeToTileDimension2D_Dim1) { + HloSharding sharding = HloSharding::Tile(Array2D({{0, 1}, {2, 3}})); + HloSharding result = + ReshapeToTileDimension(sharding, /*dim=*/1, /*dims=*/{0, 1}); + EXPECT_EQ(result.tile_assignment(), Array2D({{0, 2, 1, 3}})); +} + +TEST(HloShardingUtilTest, ReshapeToTileDimension3D_Dim0) { + HloSharding sharding = + HloSharding::Tile(Array3D({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}})); + HloSharding result = + ReshapeToTileDimension(sharding, /*dim=*/0, /*dims=*/{0, 1, 2}); + EXPECT_EQ( + result.tile_assignment(), + Array3D({{{0}}, {{1}}, {{2}}, {{3}}, {{4}}, {{5}}, {{6}}, {{7}}})); +} + +TEST(HloShardingUtilTest, ReshapeToTileDimension3D_Dim1) { + HloSharding sharding = + HloSharding::Tile(Array3D({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}})); + HloSharding result = + ReshapeToTileDimension(sharding, /*dim=*/1, /*dims=*/{0, 1, 2}); + EXPECT_EQ(result.tile_assignment(), + Array3D({{{0}, {1}, {4}, {5}, {2}, {3}, {6}, {7}}})); +} + +TEST(HloShardingUtilTest, ReshapeToTileDimension3D_Dim2) { + HloSharding sharding = + HloSharding::Tile(Array3D({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}})); + HloSharding result = + ReshapeToTileDimension(sharding, /*dim=*/2, /*dims=*/{0, 1, 2}); + EXPECT_EQ(result.tile_assignment(), + Array3D({{{0, 2, 4, 6, 1, 3, 5, 7}}})); +} + +TEST(HloShardingUtilTest, ReshapeToTileDimension2D_Dim2_Batch1) { + // Tile sharding in batch dimension, i.e. + // sharding={devices[2,2,2]0,1,2,3,4,5,6,7,8}. 
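The expected tile assignments in the ReshapeToTileDimension tests above all follow one pattern: devices are read out of the original assignment with the target dim as the slowest-varying coordinate. The 2D sketch below (illustrative, std-only, not the XLA implementation) reproduces the Dim0 and Dim1 expectations.

// Illustrative sketch only: collapse a 2D tile assignment onto `dim`.
#include <cstdint>
#include <vector>

std::vector<int64_t> CollapseToDimSketch(
    const std::vector<std::vector<int64_t>>& tile_assignment, int dim) {
  std::vector<int64_t> order;
  const int64_t rows = tile_assignment.size();
  const int64_t cols = tile_assignment[0].size();
  if (dim == 0) {
    // dim 0 varies slowest: {{0, 1}, {2, 3}} -> 0, 1, 2, 3
    for (int64_t i = 0; i < rows; ++i)
      for (int64_t j = 0; j < cols; ++j) order.push_back(tile_assignment[i][j]);
  } else {
    // dim 1 varies slowest: {{0, 1}, {2, 3}} -> 0, 2, 1, 3
    for (int64_t j = 0; j < cols; ++j)
      for (int64_t i = 0; i < rows; ++i) order.push_back(tile_assignment[i][j]);
  }
  return order;  // all devices end up along dimension `dim`
}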
+ HloSharding sharding = + HloSharding::Tile(Array3D({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}})); + // Reshape on dimensions {1, 2} only, therefore ignoring batch dimension 0. + HloSharding result = ReshapeToTileDimension(sharding, /*dim=*/2, + /*dims=*/{1, 2}); + // Expected result is {devices=[2,1,4]0,2,1,3,4,6,5,7}, i.e. the two + // non-batch dimensions {{0, 1}, {2, 3}} and {{4, 5}, {6, 7}} are individually + // reshaped to tile dimension 2, i.e. {{0, 2, 1, 3}}, {{4, 6, 5, 7}}. + EXPECT_EQ(result.tile_assignment(), + Array3D({{{0, 2, 1, 3}}, {{4, 6, 5, 7}}})); +} + +} // namespace +} // namespace hlo_sharding_util +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index acc077ab12d..e57c8a83b23 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -91,8 +91,7 @@ string HloValue::ToShortString() const { return absl::StrFormat( "<%d %s%s%s%s>", id(), instruction()->name(), instruction()->shape().IsTuple() ? index().ToString() : "", - is_phi() ? " (phi)" : "", - has_color() ? StrCat(" @", color().value()) : ""); + is_phi() ? " (phi)" : "", has_color() ? StrCat(" @", color()) : ""); } string HloValue::ToString(int indent) const { diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 0911af10f38..4661b8fd9e3 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -74,7 +74,6 @@ Status CheckParameterCount(const HloInstruction* calling_instruction, } return Status::OK(); } - } // namespace Status ShapeVerifier::Preprocess(HloInstruction* hlo) { @@ -236,6 +235,40 @@ static Status CheckReplicaGroups(HloInstruction* hlo) { return Status::OK(); } +Status ShapeVerifier::HandleAllGather(HloInstruction* hlo) { + auto ag = Cast(hlo); + TF_RETURN_IF_ERROR(CheckReplicaGroups(ag)); + TF_RET_CHECK(ag->all_gather_dimension() >= 0); + TF_RET_CHECK(ag->all_gather_dimension() < ag->shape().rank()); + TF_RET_CHECK(ag->all_gather_dimension() < ag->operand(0)->shape().rank()); + if (ag->use_global_device_ids() && ag->replica_groups().empty()) { + return InternalError( + "Replica group must be specified when use_global_device_ids is true"); + } + + int64 shard_count = CeilOfRatio( + ag->shape().dimensions(ag->all_gather_dimension()), + ag->operand(0)->shape().dimensions(ag->all_gather_dimension())); + if (ag->channel_id().has_value()) { + if (ag->use_global_device_ids()) { + TF_RET_CHECK(shard_count == ag->replica_groups()[0].replica_ids_size()); + } else { + if (ag->replica_groups().empty() || + ag->replica_groups()[0].replica_ids_size() != 1) { + return InternalError( + "Replica group size must be 1 when use_global_device_ids is " + "false if the all-gather is also cross-partition"); + } + } + } else if (!ag->replica_groups().empty()) { + // Cross-replica all-gather: shard count is subgroup size. 
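The shard-count bookkeeping that HandleAllGather introduces here reduces to a simple consistency rule; the sketch below restates the cross-replica case with plain std types (illustrative names, not the verifier's API).

// Illustrative sketch only.
#include <cstdint>
#include <vector>

int64_t CeilOfRatioSketch(int64_t a, int64_t b) { return (a + b - 1) / b; }

// For a cross-replica all-gather (no channel id), the number of shards
// concatenated along the all-gather dimension must match the replica group
// size when explicit groups are given.
bool AllGatherShapeConsistentSketch(
    int64_t operand_dim_size, int64_t output_dim_size,
    const std::vector<std::vector<int64_t>>& replica_groups) {
  const int64_t shard_count =
      CeilOfRatioSketch(output_dim_size, operand_dim_size);
  if (replica_groups.empty()) {
    return shard_count >= 1;  // flat participation; nothing more to check
  }
  return shard_count == static_cast<int64_t>(replica_groups[0].size());
}

// Example: gathering a f32[4,16] operand into f32[16,16] along dimension 0
// with groups {{0,1,2,3}} gives shard_count == 4 == group size.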
+ TF_RET_CHECK(shard_count == ag->replica_groups()[0].replica_ids_size()); + } + return CheckShape(ag, ShapeInference::InferAllGatherShape( + ag->operand(0)->shape(), ag->all_gather_dimension(), + shard_count)); +} + Status ShapeVerifier::HandleAllReduce(HloInstruction* crs) { TF_RETURN_IF_ERROR(CheckReplicaGroups(crs)); @@ -298,7 +331,9 @@ Status ShapeVerifier::HandleReplicaId(HloInstruction* hlo) { return CheckShape(hlo, ShapeUtil::MakeShape(U32, {})); } -Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) { +namespace { + +Status CheckDuplicatedSourceOrTarget(HloInstruction* hlo) { // A source or target cannot appear twice in the collective-permute's // source-target pairs. absl::flat_hash_set seen_sources; @@ -317,10 +352,30 @@ Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) { p.second, hlo->ToString()); } } + return Status::OK(); +} + +} // namespace + +Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) { + TF_RETURN_IF_ERROR(CheckDuplicatedSourceOrTarget(hlo)); return CheckShape(hlo, ShapeInference::InferCollectivePermuteShape( hlo->operand(0)->shape())); } +Status ShapeVerifier::HandleCollectivePermuteStart(HloInstruction* hlo) { + TF_RETURN_IF_ERROR(CheckDuplicatedSourceOrTarget(hlo)); + return CheckShape( + hlo, ShapeUtil::MakeTupleShape( + {hlo->operand(0)->shape(), hlo->operand(0)->shape(), + ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})})); +} + +Status ShapeVerifier::HandleCollectivePermuteDone(HloInstruction* hlo) { + return CheckShape( + hlo, ShapeUtil::GetTupleElementShape(hlo->operand(0)->shape(), 0)); +} + Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape( reduce_precision->operand(0)->shape(), @@ -628,9 +683,11 @@ Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) { shape_size_function_(bitcast->operand(0)->shape())) { return InternalError( "Bitcast cannot have different shape sizes of output (%d) and operand " - "(%d)", + "(%d) (%s) (%s)", shape_size_function_(bitcast->shape()), - shape_size_function_(bitcast->operand(0)->shape())); + shape_size_function_(bitcast->operand(0)->shape()), + bitcast->shape().ToString(true), + bitcast->operand(0)->shape().ToString(true)); } return Status::OK(); } @@ -697,11 +754,7 @@ Status ShapeVerifier::HandleFusion(HloInstruction* fusion) { } for (HloInstruction* fused_param : fused_parameters) { int64 param_no = fused_param->parameter_number(); - // Since fusion buffers aren't materialized, fusion parameters will not have - // the same memory space as the fusion operand. - if (!ShapesSame(fused_param->shape(), fusion->operand(param_no)->shape(), - /*minor_to_major_only=*/false, - /*ignore_memory_space=*/true)) { + if (!ShapesSame(fused_param->shape(), fusion->operand(param_no)->shape())) { return InternalError( "Shape mismatch between parameter number %d and its operand in " "%s.", @@ -1343,32 +1396,60 @@ Status CheckSameIsHostTransfer(const HloInstruction* instr1, return Status::OK(); } -// Checks CopyStart and CopyDone nodes. 
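The rewrite that follows replaces the inline CopyStart/CopyDone checks with generic single-user/single-operand helpers; the toy sketch below shows the pairing invariant those helpers enforce, using a hypothetical node type rather than HloInstruction.

// Illustrative sketch only.
#include <string>
#include <vector>

struct ToyNode {
  std::string opcode;
  std::vector<const ToyNode*> operands;
  std::vector<const ToyNode*> users;
};

// A "start" op must have exactly one user, that user must be the matching
// "done" op, and the "done" op's only operand must be the start op.
bool PairedCorrectlySketch(const ToyNode& start,
                           const std::string& done_opcode) {
  return start.users.size() == 1 && start.users[0]->opcode == done_opcode &&
         start.users[0]->operands.size() == 1 &&
         start.users[0]->operands[0] == &start;
}

// Usage: PairedCorrectlySketch(copy_start, "copy-done") or
// PairedCorrectlySketch(cp_start, "collective-permute-done").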
-Status VerifyAsynchronousCopies(const HloModule& module) { +Status VerifySingleUser(const HloInstruction* instruction, + HloOpcode expected_user) { + TF_RET_CHECK(instruction->users().size() == 1) + << "The " << HloOpcodeString(instruction->opcode()) + << " instruction requires one consumer, found " + << instruction->users().size(); + + const HloInstruction* user = instruction->users().front(); + TF_RET_CHECK(user->opcode() == expected_user) + << "The consumer of a " << HloOpcodeString(instruction->opcode()) + << " instruction needs to be " << HloOpcodeString(expected_user) + << ", found " << HloOpcodeString(user->opcode()); + return Status::OK(); +} + +Status VerifySingleOperand(const HloInstruction* instruction, + HloOpcode expected_operand) { + TF_RET_CHECK(instruction->operands().size() == 1) + << "The " << HloOpcodeString(instruction->opcode()) + << " instruction requires one consumer, found " + << instruction->users().size(); + + const HloInstruction* operand = instruction->operand(0); + TF_RET_CHECK(operand->opcode() == expected_operand) + << "The operand of a " << HloOpcodeString(instruction->opcode()) + << " instruction needs to be " << HloOpcodeString(expected_operand) + << ", found " << HloOpcodeString(operand->opcode()); + return Status::OK(); +} + +// Checks asynchronous instruction pairs. +Status VerifyAsynchronousInstructionPairs(const HloModule& module) { // CopyStart must have a single CopyDone user. for (const HloComputation* computation : module.computations()) { for (const HloInstruction* instruction : computation->instructions()) { switch (instruction->opcode()) { case HloOpcode::kCopyStart: { - TF_RET_CHECK(instruction->users().size() == 1) - << "CopyStart instruction requires one consumer, found " - << instruction->users().size(); - const HloInstruction* copy_done = instruction->users().front(); - TF_RET_CHECK(copy_done->opcode() == HloOpcode::kCopyDone) - << "The consumer of a CopyStart instruction needs to be " - "CopyDone, found " - << HloOpcodeString(copy_done->opcode()); + TF_RETURN_IF_ERROR( + VerifySingleUser(instruction, HloOpcode::kCopyDone)); break; } case HloOpcode::kCopyDone: { - TF_RET_CHECK(instruction->operands().size() == 1) - << "CopyDone instruction requires one operand, found " - << instruction->operands().size(); - const HloInstruction* copy_start = instruction->operand(0); - TF_RET_CHECK(copy_start->opcode() == HloOpcode::kCopyStart) - << "The operand of a CopyDone instruction needs to be CopyStart, " - "found " - << HloOpcodeString(copy_start->opcode()); + TF_RETURN_IF_ERROR( + VerifySingleOperand(instruction, HloOpcode::kCopyStart)); + break; + } + case HloOpcode::kCollectivePermuteStart: { + TF_RETURN_IF_ERROR( + VerifySingleUser(instruction, HloOpcode::kCollectivePermuteDone)); + break; + } + case HloOpcode::kCollectivePermuteDone: { + TF_RETURN_IF_ERROR(VerifySingleOperand( + instruction, HloOpcode::kCollectivePermuteStart)); break; } default: @@ -1783,7 +1864,7 @@ StatusOr HloVerifier::Run(HloModule* module) { } TF_RETURN_IF_ERROR(VerifyHloStructure(module)); - TF_RETURN_IF_ERROR(VerifyAsynchronousCopies(*module)); + TF_RETURN_IF_ERROR(VerifyAsynchronousInstructionPairs(*module)); TF_RETURN_IF_ERROR(VerifyChannels(*module)); std::unique_ptr shape_verifier = diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 2e83361a591..85b02e0518c 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -56,9 +56,12 @@ class 
ShapeVerifier : public DfsHloVisitor { Status HandleFft(HloInstruction* fft) override; Status HandleCholesky(HloInstruction* hlo) override; Status HandleTriangularSolve(HloInstruction* hlo) override; + Status HandleAllGather(HloInstruction* hlo) override; Status HandleAllReduce(HloInstruction* crs) override; Status HandleAllToAll(HloInstruction* hlo) override; Status HandleCollectivePermute(HloInstruction* hlo) override; + Status HandleCollectivePermuteStart(HloInstruction* hlo) override; + Status HandleCollectivePermuteDone(HloInstruction* hlo) override; Status HandlePartitionId(HloInstruction* hlo) override; Status HandleReplicaId(HloInstruction* hlo) override; Status HandleReducePrecision(HloInstruction* reduce_precision) override; diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index e2c363e40c5..d9709c50df9 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -710,7 +710,7 @@ TEST_F(HloVerifierTest, CopyStartMultipleCopyDone) { ASSERT_FALSE(status.ok()); EXPECT_THAT( status.error_message(), - HasSubstr("CopyStart instruction requires one consumer, found 2")); + HasSubstr("copy-start instruction requires one consumer, found 2")); } TEST_F(HloVerifierTest, CopyDoneNoCopyStart) { @@ -730,8 +730,8 @@ TEST_F(HloVerifierTest, CopyDoneNoCopyStart) { auto status = verifier().Run(module.get()).status(); ASSERT_FALSE(status.ok()); EXPECT_THAT(status.error_message(), - HasSubstr("The operand of a CopyDone instruction needs to be " - "CopyStart, found tuple")); + HasSubstr("The operand of a copy-done instruction needs to be " + "copy-start, found tuple")); } TEST_F(HloVerifierTest, IotaNonArrayResult) { @@ -1134,5 +1134,87 @@ TEST_F(HloVerifierTest, CollectiveChannelVerifier) { HasSubstr("used for different types of channel instructions")); } +TEST_F(HloVerifierTestLayoutSensitive, CollectivePermuteStartAndDone) { + const char* const kModuleStr = R"( + HloModule Module + + ENTRY CollectivePermuteStartAndDone { + p0 = f32[2,3]{1,0:S(1)} parameter(0) + collective-permute-start.1 = (f32[2,3]{1,0:S(1)}, f32[2,3]{1,0:S(1)}, u32[], u32[]) collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1 + ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + +TEST_F(HloVerifierTest, CollectivePermuteStartAndDoneWrongType) { + const char* const kModuleStr = R"( + HloModule Module + + ENTRY CollectivePermuteStartAndDoneWrongType { + p0 = f32[2,3]{1,0:S(1)} parameter(0) + collective-permute-start.1 = f32[2,3]{1,0:S(1)} collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1 + ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Expected instruction to have shape equal to " + "(f32[2,3], f32[2,3], u32[], u32[])")); +} + +TEST_F(HloVerifierTest, CollectivePermuteStartAndMultipleDone) { + const char* const kModuleStr = R"( + HloModule Module + + ENTRY CollectivePermuteStartAndMultipleDone { + p0 = f32[2,3]{1,0:S(1)} 
parameter(0) + collective-permute-start.1 = (f32[2,3]{1,0:S(1)}, f32[2,3]{1,0:S(1)}, u32[], u32[]) collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1 + collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1) + ROOT collective-permute-done.2 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT( + status.error_message(), + HasSubstr("collective-permute-start instruction requires one consumer, " + "found 2")); +} + +TEST_F(HloVerifierTest, CollectivePermuteDoneNoCollectivePermuteStart) { + const char* const kModuleStr = R"( + HloModule Module + + ENTRY CollectivePermuteDoneNoCollectivePermuteStart { + p0 = f32[2,3]{1,0:S(1)} parameter(0) + p1 = f32[2,3]{1,0:S(1)} parameter(1) + p2 = u32[] parameter(2) + p3 = u32[] parameter(3) + tuple.1 = (f32[2,3], f32[2,3], u32[], u32[]) tuple(p0, p1, p2, p3) + ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(tuple.1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("The operand of a collective-permute-done instruction " + "needs to be collective-permute-start, found tuple")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 99242c9ca21..02966cc2bf2 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -145,9 +145,12 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kCholesky: case HloOpcode::kConditional: case HloOpcode::kConvolution: + case HloOpcode::kAllGather: case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: + case HloOpcode::kCollectivePermuteDone: + case HloOpcode::kCollectivePermuteStart: case HloOpcode::kCustomCall: case HloOpcode::kDomain: case HloOpcode::kDot: @@ -501,7 +504,7 @@ StatusOr InstructionFusion::Run(HloModule* module) { while (true) { auto next_entry = fusion_queue->DequeueNextInstructionAndOperandsToFuseInOrder(); - auto instruction = next_entry.first; + HloInstruction* instruction = next_entry.first; if (instruction == nullptr) { break; } @@ -511,12 +514,14 @@ StatusOr InstructionFusion::Run(HloModule* module) { continue; } + VLOG(5) << "Considering fusion of: " << instruction->ToString(); std::vector& sorted_operand_numbers = next_entry.second; for (int64 i : sorted_operand_numbers) { HloInstruction* operand = instruction->mutable_operand(i); if (!operand->IsFusible()) { + VLOG(3) << "Operand (" << operand->ToString() << ") is not fusible"; continue; } @@ -690,6 +695,8 @@ bool InstructionFusion::ShouldFuse(HloInstruction* consumer, if (FusionWouldDuplicate(*producer, *consumer) && (!may_duplicate_ || is_expensive_(*producer)) && !IsAlwaysDuplicable(*producer)) { + VLOG(4) << "Stopping: fusion may duplicate operand (" + << producer->ToString() << ") , and this is expensive"; return false; } diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index a67c677bd03..82c30f1a710 100644 --- 
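The new VLOG in ShouldFuse above records why fusion is abandoned; as a rough sketch, the predicate below restates that decision with plain booleans (illustrative, not the pass's real signature).

// Illustrative sketch only.
bool ShouldFuseSketch(int producer_user_count, bool producer_is_expensive,
                      bool may_duplicate, bool always_duplicable) {
  // A producer with more than one consumer must be duplicated to be fused
  // into each of them.
  const bool fusion_would_duplicate = producer_user_count > 1;
  if (fusion_would_duplicate && (!may_duplicate || producer_is_expensive) &&
      !always_duplicable) {
    return false;  // would replicate an expensive computation per consumer
  }
  return true;
}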
a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -951,7 +951,8 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { if (!Shape::Equal() .IgnoreDynamicDimension() .MinorToMajorOnlyInLayout()(instruction_subshape, - buffer->shape())) { + buffer->shape()) && + instruction->opcode() != HloOpcode::kBitcast) { return InternalError( "Layout of instruction %s at index {%s} does not match " "source LogicalBuffer %s: %s vs %s", @@ -1798,13 +1799,6 @@ Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) { // potential bugs in the layout assignment pass that may accidentally use the // existing layout. for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kBitcast) { - // bitcasts are inherently layout sensitive and so a bitcast instruction - // present in the IR before layout assignment is a bug. - return InternalError( - "Unexpected bitcast operation seen during layout assignment: %s.", - instruction->ToString()); - } // Some instructions carry mandatory layouts in their shape. if (instruction->opcode() != HloOpcode::kInfeed && !IsLayoutConstrainedCustomCall(instruction) && @@ -2179,6 +2173,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kConditional: case HloOpcode::kConvert: case HloOpcode::kCos: + case HloOpcode::kAllGather: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: case HloOpcode::kDivide: @@ -2239,6 +2234,8 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kBitcast: case HloOpcode::kBroadcast: case HloOpcode::kCall: + case HloOpcode::kCollectivePermuteStart: + case HloOpcode::kCollectivePermuteDone: case HloOpcode::kConstant: case HloOpcode::kConvolution: case HloOpcode::kCopy: diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 304a80c7a52..6e575247e6b 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -814,27 +814,6 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { EXPECT_THAT(false_result->opcode(), HloOpcode::kCopy); } -TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) { - auto builder = HloComputation::Builder(TestName()); - auto constant0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( - {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); - builder.AddInstruction( - HloInstruction::CreateBitcast(constant0->shape(), constant0)); - auto m = CreateNewVerifiedModule(); - m->AddEntryComputation(builder.Build()); - - ComputationLayout computation_layout( - m->entry_computation()->ComputeProgramShape()); - LayoutAssignment layout_assignment(&computation_layout); - Status error_status = layout_assignment.Run(m.get()).status(); - EXPECT_FALSE(error_status.ok()); - EXPECT_THAT( - error_status.error_message(), - ::testing::HasSubstr( - "Unexpected bitcast operation seen during layout assignment")); -} - TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { // Pin non matching layouts to parameter and root. 
const char* module_str = R"( diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 39399df7ad8..cabcc8e06ee 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -64,6 +64,7 @@ cc_library( srcs = ["llvm_util.cc"], hdrs = ["llvm_util.h"], deps = [ + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc index 453a5cd84b2..f7808773592 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc @@ -58,7 +58,7 @@ ENTRY while3 { CompileAndVerifyIr(hlo_string, R"( ; CHECK-LABEL: @body(i8* %retval -; CHECK: %[[add_result:.*]] = fadd fast float %[[fadd_lhs:.*]], %[[fadd_rhs:.*]] +; CHECK: %[[add_result:.*]] = fadd reassoc nsz contract float %[[fadd_lhs:.*]], %[[fadd_rhs:.*]] ; CHECK: store float %[[add_result]], float* %[[store_dest:.*]], align 4, !alias.scope ![[alias_scope_md_for_store:[0-9]+]] ; ; CHECK-LABEL: @condition(i8* %retval, i8* noalias %run_options, i8** noalias %params diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index da0dbf94ddd..278aa3e1696 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -20,6 +20,7 @@ limitations under the License. #include "llvm/IR/Constants.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -373,6 +374,28 @@ llvm::Value* IrArray::Index::Linearize(absl::Span dimensions, return logical_linear_index; } +llvm::Value* IrArray::Index::Linearize( + const std::vector& dynamic_dims, + llvm::IRBuilder<>* builder) const { + // Each dimension is multiplied by the product of the sizes of all + // earlier dimensions and added to the accumulator logical_linear_index. + CHECK_EQ(size(), dynamic_dims.size()); + llvm::Value* logical_linear_index = GetConstantWithIndexType(0); + llvm::Value* multiplier = GetConstantWithIndexType(1); + for (ssize_t i = size() - 1; i >= 0; --i) { + llvm::Value* addend = builder->CreateMul((*this)[i], multiplier, "", + /*HasNUW=*/true, /*HasNSW=*/true); + addend = builder->CreateZExtOrTrunc(addend, index_type_); + logical_linear_index = builder->CreateAdd(logical_linear_index, addend, "", + /*HasNUW=*/true, /*HasNSW=*/true); + if (i) { + multiplier = builder->CreateMul(multiplier, dynamic_dims[i], + /*Name=*/"multiplier"); + } + } + return logical_linear_index; +} + llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index, llvm::IRBuilder<>* b, absl::string_view name, diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index e838c4a0534..c71654f5294 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -155,6 +155,10 @@ class IrArray { llvm::Value* Linearize(absl::Span dimensions, llvm::IRBuilder<>* builder) const; + // Linearizes the index into the given dynamic dimensions. 
+ llvm::Value* Linearize(const std::vector& dynamic_dims, + llvm::IRBuilder<>* builder) const; + llvm::Type* GetType() const { return index_type_; } llvm::Constant* GetConstantWithIndexType(int64 c) const { diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index 4c9a8d3e004..6375bf7341f 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -30,6 +30,7 @@ limitations under the License. #include "llvm/Support/CommandLine.h" #include "llvm/Target/TargetOptions.h" #include "llvm/Transforms/Utils/Cloning.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" @@ -90,7 +91,9 @@ llvm::CallInst* EmitCallToIntrinsic( llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value, llvm::IRBuilder<>* b) { - if (b->getFastMathFlags().noNaNs()) { + // TODO(tpopp): Pass this information down from the HLO's ModuleConfig. + if (b->getFastMathFlags().noNaNs() || + GetDebugOptionsFromFlags().xla_cpu_enable_fast_min_max()) { auto cmp = b->CreateFCmpUGE(lhs_value, rhs_value); return b->CreateSelect(cmp, lhs_value, rhs_value); } else { @@ -103,7 +106,9 @@ llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value, llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value, llvm::IRBuilder<>* b) { - if (b->getFastMathFlags().noNaNs()) { + // TODO(tpopp): Pass this information down from the HLO's ModuleConfig. + if (b->getFastMathFlags().noNaNs() || + GetDebugOptionsFromFlags().xla_cpu_enable_fast_min_max()) { auto cmp = b->CreateFCmpULE(lhs_value, rhs_value); return b->CreateSelect(cmp, lhs_value, rhs_value); } else { @@ -287,7 +292,7 @@ llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type, llvm::AllocaInst* alloca = b->CreateAlloca(type, element_count, AsStringRef(name)); if (alignment != 0) { - alloca->setAlignment(llvm::MaybeAlign(alignment)); + alloca->setAlignment(llvm::Align(alignment)); } return alloca; } diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index 83be4334269..b6b3b2dd8b3 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -35,6 +35,14 @@ LoopEmitter::LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape, llvm::IRBuilder<>* b) : body_emitter_(body_emitter), shape_(shape), b_(b) {} +LoopEmitter::LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape, + std::vector dynamic_dims, + llvm::IRBuilder<>* b) + : LoopEmitter::LoopEmitter(body_emitter, shape, b) { + CHECK_EQ(dynamic_dims.size(), shape_.dimensions_size()); + dynamic_dims_ = std::move(dynamic_dims); +} + LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, const IrArray& target_array, llvm::IRBuilder<>* b) : body_emitter_([=](const llvm_ir::IrArray::Index array_index) -> Status { @@ -84,6 +92,43 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, } } +IrArray::Index LoopEmitter::EmitStaticIndex(ForLoopNest* loop_nest, + llvm::Type* index_type) { + // Create loop nest with one for-loop for each dimension of the target shape. 
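The new Linearize overload above emits the usual row-major index arithmetic over runtime dimension sizes; the host-side sketch below (illustrative names) computes the same value on ordinary integers.

// Illustrative sketch only: dims are dimension sizes, most-major first.
#include <cstdint>
#include <vector>

int64_t LinearizeSketch(const std::vector<int64_t>& index,
                        const std::vector<int64_t>& dims) {
  int64_t linear = 0;
  int64_t multiplier = 1;
  // Walk from the most-minor dimension outward, as the IR builder version
  // does: each component is scaled by the product of all more-minor sizes.
  for (int64_t i = static_cast<int64_t>(index.size()) - 1; i >= 0; --i) {
    linear += index[i] * multiplier;
    multiplier *= dims[i];
  }
  return linear;
}

// Example: index {1, 2} with dims {3, 5} -> 1 * 5 + 2 = 7.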
+ // Loops are added from outermost to innermost order with the ForLoopNest + // class so emit loops in order from most-major dimension down to most-minor + // dimension (of the target shape). + std::vector array_multi_index(shape_.dimensions_size()); + for (int i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) { + int64 dimension = LayoutUtil::Major(shape_.layout(), i); + std::unique_ptr loop = loop_nest->AddLoop( + /*start_index=*/0, + /*end_index=*/shape_.dimensions(dimension), + /*suffix=*/absl::StrFormat("dim.%d", dimension)); + array_multi_index[dimension] = loop->GetIndVarValue(); + } + return IrArray::Index(array_multi_index, shape_, index_type); +} + +IrArray::Index LoopEmitter::EmitDynamicIndex(ForLoopNest* loop_nest, + llvm::Type* index_type) { + CHECK_EQ(shape_.is_dynamic(), true); + // Create loop nest with one for-loop for each dynamic dimensions. + // Loops are added from outermost to innermost order with the ForLoopNest + // class so emit loops in order from most-major dimension down to most-minor + // dimension (of the target shape). + std::vector array_multi_index(shape_.dimensions_size()); + for (int i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) { + int64 dimension = LayoutUtil::Major(shape_.layout(), i); + std::unique_ptr loop = loop_nest->AddLoop( + /*suffix=*/absl::StrFormat("dim.%d", dimension), + /*start_index=*/llvm::ConstantInt::get(index_type, 0), + /*end_index=*/dynamic_dims_[dimension]); + array_multi_index[dimension] = loop->GetIndVarValue(); + } + return IrArray::Index(array_multi_index, shape_, index_type); +} + std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( absl::string_view loop_name, llvm::Type* index_type) { CHECK_NE(index_type, nullptr); @@ -93,21 +138,11 @@ std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( return {IrArray::Index(index_type)}; } - // Create loop nest with one for-loop for each dimension of the target shape. - // Loops are added from outermost to innermost order with the ForLoopNest - // class so emit loops in order from most-major dimension down to most-minor - // dimension (of the target shape). ForLoopNest loop_nest(loop_name, b_); - std::vector array_multi_index(shape_.dimensions_size()); - for (int i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) { - int64 dimension = LayoutUtil::Major(shape_.layout(), i); - std::unique_ptr loop = loop_nest.AddLoop( - /*start_index=*/0, - /*end_index=*/shape_.dimensions(dimension), - /*suffix=*/absl::StrFormat("dim.%d", dimension)); - array_multi_index[dimension] = loop->GetIndVarValue(); - } - IrArray::Index array_index(array_multi_index, shape_, index_type); + + IrArray::Index array_index = dynamic_dims_.empty() + ? EmitStaticIndex(&loop_nest, index_type) + : EmitDynamicIndex(&loop_nest, index_type); // Set IR builder insertion point to the loop body basic block of the // innermost loop. diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h index a537c00066b..008205a642a 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h @@ -22,6 +22,7 @@ limitations under the License. 
#include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/statusor.h" namespace xla { @@ -42,6 +43,12 @@ class LoopEmitter { LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape, llvm::IRBuilder<>* b); + + // Constructs a LoopEmitter from an body_emitter that generates + // element of the given target array in the dynamic dimension. + LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape, + std::vector dynamic_dims, llvm::IRBuilder<>* b); + // Constructs a LoopEmitter from an element generator that generates each // element of the given target array. LoopEmitter(const ElementGenerator& target_element_generator, @@ -81,11 +88,21 @@ class LoopEmitter { // The shape that the emitted loop iterates through. Shape shape_; + // Dynamic dimensions that emitted loop iterates through. Generate the + // loop based on the dynamic dimensions if this vector is not empty. + std::vector dynamic_dims_; + // Points to the exit block of the emitted loop. If the given shape is // scalar, no loops are emitted and exit_bb_ is nullptr in that case. llvm::BasicBlock* exit_bb_; llvm::IRBuilder<>* b_; + + private: + IrArray::Index EmitStaticIndex(ForLoopNest* loop_nest, + llvm::Type* index_type); + IrArray::Index EmitDynamicIndex(ForLoopNest* loop_nest, + llvm::Type* index_type); }; } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index ef8ddfc1a76..c80646e0c70 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -112,6 +112,8 @@ ExecutionOptions CreateExecutionOptions( } execution_options.set_num_replicas(build_options.num_replicas()); execution_options.set_num_partitions(build_options.num_partitions()); + execution_options.set_use_spmd_partitioning( + build_options.use_spmd_partitioning()); if (build_options.has_device_assignment()) { TF_CHECK_OK(build_options.device_assignment().Serialize( execution_options.mutable_device_assignment())); diff --git a/tensorflow/compiler/xla/service/logical_buffer.cc b/tensorflow/compiler/xla/service/logical_buffer.cc index e1f56727bd2..d937d53d550 100644 --- a/tensorflow/compiler/xla/service/logical_buffer.cc +++ b/tensorflow/compiler/xla/service/logical_buffer.cc @@ -34,7 +34,7 @@ LogicalBuffer::~LogicalBuffer() {} string LogicalBuffer::ToString() const { string color_string; if (has_color()) { - color_string = absl::StrCat(" @", color().value()); + color_string = absl::StrCat(" @", color()); } return absl::StrCat(instruction_->name(), "[", absl::StrJoin(index_, ","), "](#", id(), color_string, ")"); diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index 8752e870bb7..0ed72f51754 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -16,53 +16,52 @@ limitations under the License. #include "tensorflow/compiler/xla/service/memory_space_assignment.h" #include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/core/lib/math/math_util.h" namespace xla { namespace { // Define a dummy chunk for chunks that will be allocated in the default memory // space and for keeping track of number of asynchronous copies. 
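Both the static and the dynamic loop-emission paths declared above walk dimensions from most-major to most-minor; the host-side sketch below (illustrative, std-only) models the loop nest they emit, with the bounds supplied either by the static shape or by runtime dimension sizes.

// Illustrative sketch only.
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

void ForEachIndexSketch(
    const std::vector<int64_t>& dims,  // major-to-minor dimension sizes
    const std::function<void(const std::vector<int64_t>&)>& body) {
  std::vector<int64_t> index(dims.size(), 0);
  std::function<void(std::size_t)> recurse = [&](std::size_t d) {
    if (d == dims.size()) {
      body(index);  // innermost point, i.e. the emitted loop body
      return;
    }
    // One loop per dimension, outermost loop for the most-major dimension.
    for (index[d] = 0; index[d] < dims[d]; ++index[d]) {
      recurse(d + 1);
    }
  };
  recurse(0);
}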
const HeapSimulator::Chunk kDummyChunk{-1, -1}; +// This variable is used by the cost analysis in estimating how many times each +// while loop will execute. Nested loops will be assumed to have executed +// pow(kWhileExecutionCount, nesting_level) times. +const int kWhileExecutionCount = 5; -// Returns a heuristic value that captures how much putting this tensor to -// the alternate memory would help if the op is memory bound, or otherwise -// how far off is the op to memory boundedness. The larger this number, the -// higher priority it will be placed in the alternate memory. -float GetAlternateMemoryBenefit( - const MemorySpaceAssignmentCostAnalysis& cost_analysis, +} // namespace + +float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit( const HloInstruction& instruction, - float elapsed_time_due_to_alternate_mem) { + float elapsed_time_due_to_alternate_mem) const { float elapsed_time_due_to_compute = - cost_analysis.GetInstructionElapsedDueToCompute(instruction); + GetInstructionElapsedDueToCompute(instruction); float elapsed_time_due_to_memory = - cost_analysis.GetInstructionElapsedDueToMemory(instruction); + GetInstructionElapsedDueToMemory(instruction); if (elapsed_time_due_to_memory > elapsed_time_due_to_compute) { // Memory bound, return how much alternate memory is better. - return elapsed_time_due_to_memory - elapsed_time_due_to_alternate_mem; + int while_nest_level = CalculateWhileLoopNestLevel(&instruction); + return (elapsed_time_due_to_memory - elapsed_time_due_to_alternate_mem) * + tensorflow::MathUtil::IPow(kWhileExecutionCount, + while_nest_level); } else { // Compute bound, return how far off are we to memory boundedness. return elapsed_time_due_to_memory - elapsed_time_due_to_compute; } } -// Returns a heuristic value of memory boundedness for the given BufferInterval. -// The larger this number, the higher priority it will be placed in the -// alternate memory. -float GetMemoryBoundedness( - const MemorySpaceAssignmentCostAnalysis& cost_analysis, - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) { +float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness( + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const { const HloInstruction& defining_instruction = *interval.buffer->defining_instruction(); - float alternate_mem_benefit = - GetAlternateMemoryBenefit(cost_analysis, defining_instruction, - cost_analysis.GetInstructionElapsedDueToMemory( - defining_instruction, - /*operand_in_alternate_mem=*/{}, - /*output_in_alternate_mem=*/true)); + float alternate_mem_benefit = GetAlternateMemoryBenefit( + defining_instruction, + GetInstructionElapsedDueToMemory(defining_instruction, + /*operand_in_alternate_mem=*/{}, + /*output_in_alternate_mem=*/true)); for (const HloUse& use : interval.buffer->uses()) { float use_alternate_mem_benefit = GetAlternateMemoryBenefit( - cost_analysis, *use.instruction, - cost_analysis.GetInstructionElapsedDueToMemory(*use.instruction, - use.operand_number)); + *use.instruction, + GetInstructionElapsedDueToMemory(*use.instruction, use.operand_number)); // If the benefit is positive (memory bound), add it to this buffer's // benefit. If the benefit is negative (compute bound), calculate the // maximum. @@ -77,7 +76,7 @@ float GetMemoryBoundedness( // Get performance slowdown in seconds of prefetching current BufferInterval // causing to other BufferIntervals. 
float alternate_mem_slowdown = - cost_analysis.GetInstructionElapsedDueToMemorySlowdown(interval.size); + GetInstructionElapsedDueToMemorySlowdown(interval.size); // Scale the slowdown based on the time of this buffer. We would want earlier // buffers have lower slowdown values, because they are less likely to overlap @@ -86,13 +85,28 @@ float GetMemoryBoundedness( // for early HLOs, and full slowdown for mid-to-late HLOs. // TODO(yuemmawang): Further in a smarter way, we want buffers overlapped with // more HLOs have higher slowdown, and vice versa. - float scale = interval.start * 1.0 / cost_analysis.GetScheduleEndTime(); + float scale = interval.start * 1.0 / GetScheduleEndTime(); alternate_mem_slowdown *= scale; return alternate_mem_benefit - alternate_mem_slowdown; } -} // namespace +int MemorySpaceAssignmentCostAnalysis::CalculateWhileLoopNestLevel( + const HloInstruction* instruction) const { + int nest_level = 0; + const HloComputation* computation = instruction->parent(); + while (!computation->IsEntryComputation()) { + auto node = call_graph_.GetNode(computation); + auto callsites = node.caller_callsites(); + CHECK_EQ(callsites.size(), 1) << "The module is not flattened!"; + auto callsite = callsites[0]; + if (callsite.instruction()->opcode() == HloOpcode::kWhile) { + ++nest_level; + } + computation = callsite.instruction()->parent(); + } + return nest_level; +} float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToCompute( const HloInstruction& instruction) const { @@ -207,29 +221,30 @@ CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker( const MemorySpaceAssignmentCostAnalysis& cost_analysis, float min_async_copy_to_overlap_ratio, float max_async_copy_to_overlap_ratio) - : cost_analysis_(cost_analysis), + : elapsed_time_( + cost_analysis.hlo_live_range().instruction_schedule().size(), 0.0), + while_nest_level_( + cost_analysis.hlo_live_range().instruction_schedule().size(), 0), + cost_analysis_(cost_analysis), min_async_copy_to_overlap_ratio_(min_async_copy_to_overlap_ratio), max_async_copy_to_overlap_ratio_(max_async_copy_to_overlap_ratio) { instruction_schedule_ = &cost_analysis_.hlo_live_range().instruction_schedule(); - // First create a vector of elapsed times of HLO instructions. - std::vector instructions_elapsed_time(instruction_schedule_->size(), - 0.0); + // Create a vector of elapsed times and while nesting levels of HLO + // instructions. for (const auto& instruction_and_logical_time : *instruction_schedule_) { float elapsed_time = cost_analysis_.cost_analysis().optimal_seconds( *instruction_and_logical_time.first); int64 logical_time = instruction_and_logical_time.second; - if (logical_time >= instructions_elapsed_time.size()) { - instructions_elapsed_time.resize(logical_time + 1, 0.0); + if (logical_time >= elapsed_time_.size()) { + elapsed_time_.resize(logical_time + 1, 0.0); + while_nest_level_.resize(logical_time + 1, 0); } - instructions_elapsed_time[logical_time] = elapsed_time; - } - // As an optimization, create a cumulative sum vector of elapsed time. 
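CalculateWhileLoopNestLevel above walks the (flattened) call graph and counts while-loop callers; the toy sketch below shows the same walk on a hypothetical computation record.

// Illustrative sketch only.
#include <string>

struct ToyComputation {
  const ToyComputation* caller = nullptr;  // nullptr for the entry computation
  std::string caller_opcode;  // opcode of the instruction calling this one
};

int WhileNestLevelSketch(const ToyComputation* computation) {
  int nest_level = 0;
  while (computation->caller != nullptr) {
    if (computation->caller_opcode == "while") {
      ++nest_level;
    }
    computation = computation->caller;
  }
  return nest_level;
}

// With kWhileExecutionCount = 5, a buffer defined two while loops deep is
// assumed to be touched 5^2 = 25 times as often as one in the entry
// computation.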
- float cumsum = 0.0; - for (float elapsed_time : instructions_elapsed_time) { - cumsum += elapsed_time; - elapsed_time_cumsum_.push_back(cumsum); + elapsed_time_[logical_time] = elapsed_time; + while_nest_level_[logical_time] = + cost_analysis_.CalculateWhileLoopNestLevel( + instruction_and_logical_time.first); } } @@ -275,7 +290,7 @@ void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use, end_logical_time_ = end_time; // Find the earliest time we're allowed to start prefetching. for (current_logical_prefetch_time_ = start_time; - current_logical_prefetch_time_ <= end_logical_time_ && + current_logical_prefetch_time_ < end_logical_time_ && max_async_copy_to_overlap_ratio_ * async_copy_elapsed_ < GetLogicalIntervalElapsed(current_logical_prefetch_time_, end_logical_time_); @@ -290,9 +305,9 @@ int64 CostAnalysisPrefetchIntervalPicker::Next() { } bool CostAnalysisPrefetchIntervalPicker::Done() const { - // The end time is inclusive, so we're done if the prefetch time is greater - // than that. - if (current_logical_prefetch_time_ > end_logical_time_) { + // The end time is exclusive, so we're done if the prefetch time is greater + // than or equal to the end time. + if (current_logical_prefetch_time_ >= end_logical_time_) { return true; } float logical_interval_elapsed = GetLogicalIntervalElapsed( @@ -303,7 +318,17 @@ bool CostAnalysisPrefetchIntervalPicker::Done() const { float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed( int64 start_time, int64 end_time) const { - return elapsed_time_cumsum_[end_time - 1] - elapsed_time_cumsum_[start_time]; + int interval_nest_level = + std::min(while_nest_level_[start_time], while_nest_level_[end_time]); + float total_elapsed = 0; + for (int i = start_time + 1; i < end_time; ++i) { + total_elapsed += + elapsed_time_[i] * + tensorflow::MathUtil::IPow( + kWhileExecutionCount, + std::max(0, while_nest_level_[i] - interval_nest_level)); + } + return total_elapsed; } std::string CostAnalysisPrefetchIntervalPicker::ToDebugString() const { @@ -328,7 +353,7 @@ std::string CostAnalysisPrefetchIntervalPicker::ToNoCopyDebugString( absl::optional CostAnalysisPrefetchIntervalPicker::BufferIntervalAlternateMemoryBenefit( const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const { - return GetMemoryBoundedness(cost_analysis_, interval); + return cost_analysis_.GetMemoryBoundedness(interval); } std::string MemorySpaceAssignment::AllocationValue::ToString() const { @@ -496,13 +521,28 @@ bool AlternateMemoryBestFitHeap::IsIntervalAllowedInAlternateMemory( return false; } } + + if ((position.instruction->opcode() == HloOpcode::kCollectivePermuteStart || + position.instruction->opcode() == HloOpcode::kCollectivePermuteDone)) { + // Disable memory space allocation for these for now. 
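GetLogicalIntervalElapsed now weights each instruction by how much deeper it sits in while loops than the interval endpoints, replacing the plain cumulative-sum lookup; the sketch below (illustrative names, with the per-iteration factor taken from the kWhileExecutionCount = 5 defined earlier in this file) reproduces that arithmetic.

// Illustrative sketch only.
#include <algorithm>
#include <cmath>
#include <vector>

float IntervalElapsedSketch(const std::vector<float>& elapsed_time,
                            const std::vector<int>& while_nest_level,
                            int start_time, int end_time,
                            int while_execution_count = 5) {
  const int interval_nest_level =
      std::min(while_nest_level[start_time], while_nest_level[end_time]);
  float total = 0.0f;
  for (int i = start_time + 1; i < end_time; ++i) {
    const int extra = std::max(0, while_nest_level[i] - interval_nest_level);
    total += elapsed_time[i] *
             static_cast<float>(std::pow(while_execution_count, extra));
  }
  return total;
}

// Example: elapsed_time {1, 1, 1, 1} with nest levels {0, 1, 1, 0} gives a
// cost of 1 * 5 + 1 * 5 = 10 for the interval (0, 3), reflecting the assumed
// repeated execution of the loop-body instructions at times 1 and 2.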
+ if (position.index == ShapeIndex({0})) { + VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + << " in default mem because it is a collective-permute buffer."; + return false; + } else if (position.index == ShapeIndex({1})) { + VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + << " in default mem because it is a collective-permute buffer."; + return false; + } + } } return true; } bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory( - const HloUse& use) const { + const AllocationValue& value, const HloUse& use) const { + const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); if (use.instruction->opcode() == HloOpcode::kWhile) { HloComputation* while_body = use.instruction->while_body(); @@ -512,7 +552,6 @@ bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory( HloValue* parameter_value = &alias_analysis_.dataflow_analysis().GetUniqueValueAt( while_body->parameter_instruction(0), use.operand_index); - const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); int64 parameter_time = instruction_schedule.at(while_body->parameter_instruction(0)); int64 root_time = instruction_schedule.at(while_body->root_instruction()); @@ -567,7 +606,54 @@ bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory( "there is a required default memory assignment."; return false; } + } else if (use.instruction->opcode() == HloOpcode::kConditional) { + // For any use of this conditional (the same value might be passed into + // multiple called computations), determine if the parameter->first use + // dependency is short. + int64 conditional_time = instruction_schedule.at(use.instruction); + for (const HloUse& other_use : value.uses()) { + if (other_use.instruction != use.instruction) { + continue; + } + HloComputation* called_computation = + use.instruction->called_computations().at(other_use.operand_number - + 1); + const HloInstruction* parameter_instruction = + called_computation->parameter_instruction(0); + HloValue* parameter_value = + &alias_analysis_.dataflow_analysis().GetUniqueValueAt( + parameter_instruction, other_use.operand_index); + int64 parameter_time = instruction_schedule.at(parameter_instruction); + int64 min_use_time = conditional_time; + for (const HloUse& parameter_use : parameter_value->uses()) { + if (parameter_use.instruction->parent() == called_computation && + parameter_use.instruction->opcode() != + HloOpcode::kGetTupleElement && + parameter_use.instruction->opcode() != HloOpcode::kTuple && + parameter_use.instruction->opcode() != HloOpcode::kBitcast) { + min_use_time = std::min( + min_use_time, instruction_schedule.at(parameter_use.instruction)); + } + } + if (options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy( + parameter_value->shape(), parameter_time, min_use_time)) { + VLOG(4) << "Conditional allocation allowed in alternate memory for " + "computation = " + << called_computation->name() + << ", parameter time = " << parameter_time + << ", min use time = " << min_use_time; + return true; + } else { + VLOG(4) << "Conditional allocation not allowed in alternate memory for " + "computation = " + << called_computation->name() + << ", parameter time = " << parameter_time + << ", min use time = " << min_use_time; + } + } + return false; } + return true; } @@ -585,23 +671,35 @@ void AlternateMemoryBestFitHeap::AppendBufferInfoDebugString( // definition_time: int. Logical time this value was defined in the schedule. // use_times: string. 
This is a semicolon-separated list of integers for all // the use times. + // use_names: string. This is a semicolon-separated list of string + // representation of uses. if (debug_str->empty()) { // Append the column names. absl::StrAppend(debug_str, - "buffer_id,buffer_name,alt_mem_benefit,size,definition_" - "time,use_times\n"); + "buffer_id,buffer_name,alt_mem_benefit,size," + "definition_time,use_times,use_names\n"); } const HloBuffer& buffer = alias_analysis_.GetBufferContainingValue(*interval.buffer); const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); int64 definition_time = instruction_schedule.at(interval.buffer->defining_position().instruction); - std::set use_times; + std::vector> uses; for (const HloValue* value : buffer.values()) { for (const HloUse& use : value->uses()) { - use_times.insert(instruction_schedule.at(use.instruction)); + uses.push_back( + {instruction_schedule.at(use.instruction), use.ToString()}); } } + absl::c_sort(uses); + std::vector use_times; + std::vector use_names; + use_times.reserve(uses.size()); + use_names.reserve(uses.size()); + for (auto use : uses) { + use_times.push_back(use.first); + use_names.push_back(use.second); + } absl::StrAppend(debug_str, buffer.id(), ","); absl::StrAppend(debug_str, "\"", interval.buffer->ToShortString(), "\","); @@ -612,7 +710,8 @@ void AlternateMemoryBestFitHeap::AppendBufferInfoDebugString( debug_str, alternate_memory_benefit ? *alternate_memory_benefit : 0, ","); absl::StrAppend(debug_str, interval.size, ","); absl::StrAppend(debug_str, definition_time, ","); - absl::StrAppend(debug_str, "\"", absl::StrJoin(use_times, ";"), "\""); + absl::StrAppend(debug_str, "\"", absl::StrJoin(use_times, ";"), "\","); + absl::StrAppend(debug_str, "\"", absl::StrJoin(use_names, ";"), "\""); absl::StrAppend(debug_str, "\n"); } @@ -745,8 +844,6 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { } const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); - global_max_time_ = instruction_schedule.at( - module->entry_computation()->root_instruction()); // TODO(berkin): For now, place the phi values due to conditionals in // default memory. @@ -756,20 +853,12 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { if (position.instruction->opcode() == HloOpcode::kConditional) { VLOG(3) << "Adding required assignment for condition output: " << value->ToShortString(); - required_assignments_[value].push_back( - {MemorySpace::kDefault, - instruction_schedule.at(position.instruction), - /*chunk=*/absl::nullopt}); + AddRequiredAssignment(position.instruction, position.index, + MemorySpace::kDefault); for (const HloComputation* called_computation : position.instruction->called_computations()) { - HloValue* root_value = - &alias_analysis_.dataflow_analysis().GetUniqueValueAt( - called_computation->root_instruction(), position.index); - required_assignments_[root_value].push_back( - {MemorySpace::kDefault, - instruction_schedule.at( - called_computation->root_instruction()), - /*chunk=*/absl::nullopt}); + AddRequiredAssignment(called_computation->root_instruction(), + position.index, MemorySpace::kDefault); } } } @@ -795,9 +884,13 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { } // Iterate over the uses. 
- for (HloUse use : allocation_value.uses()) { + for (int use_idx = 0; use_idx < allocation_value.uses().size(); + ++use_idx) { + const HloUse& use = allocation_value.uses().at(use_idx); int64 use_time = instruction_schedule.at(use.instruction); int64 latest_prefetch_time = use_time; + bool allow_no_copy_alternate_mem_allocation = true; + absl::optional earliest_prefetch_time = absl::nullopt; // Sequential calls include kWhile, kCall, and kConditional opcodes. bool is_sequential_call = @@ -844,14 +937,41 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { // when we look at uses within the while loop body. use_time = instruction_schedule.at(while_body->parameter_instruction(0)); + } else if (use.instruction->opcode() == HloOpcode::kConditional) { + // Replace the use time with the earliest parameter of called + // computations. + for (const HloComputation* called_computation : + use.instruction->called_computations()) { + use_time = std::min( + use_time, instruction_schedule.at( + called_computation->parameter_instruction(0))); + } } } // Add a required assignment in default memory if the use not allowed in // alternate memory. - if (!IsUseAllowedInAlternateMemory(use)) { - required_assignments_[allocation_value.value()].push_back( - {MemorySpace::kDefault, use_time, /*chunk=*/absl::nullopt}); + if (!IsUseAllowedInAlternateMemory(allocation_value, use)) { + AddRequiredAssignment(allocation_value.value(), use.instruction, + MemorySpace::kDefault, use_time); + } else if (use_idx > 0) { + // We allow buffers in alternate memory that are passed into + // conditionals to give up their alternate memory allocation inside + // the called computation. This means that if a conditional operator + // has an alternate memory allocation, subsequent uses cannot use the + // same alternate memory allocation in order not to clobber data. So + // we force default memory allocation for these subsequent uses. + const HloUse& previous_use = allocation_value.uses().at(use_idx - 1); + if (previous_use.instruction->opcode() == HloOpcode::kConditional && + previous_use.instruction != use.instruction) { + allow_no_copy_alternate_mem_allocation = false; + earliest_prefetch_time = + instruction_schedule.at(previous_use.instruction); + VLOG(3) << "Previous use (" << previous_use.ToString() + << ") of use (" << use.ToString() + << ") is a conditional, so this use will need to evict. " + << "Earliest prefetch time = " << *earliest_prefetch_time; + } } // Bitcasts don't define buffers and don't directly consume buffers. @@ -859,10 +979,16 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { // bitcasts will be handled specially. if (use.instruction->opcode() != HloOpcode::kBitcast) { AllocationRequest request; - request.start_time = definition_time; + // Rarely, (e.g., when conditional true and false parameters are the + // same), definition time can be the time of the conditional and use + // time is the parameter use, which is less. 
+ request.start_time = std::min(definition_time, use_time); request.end_time = use_time; request.latest_prefetch_time = latest_prefetch_time; request.size = interval.size; + request.allow_no_copy_alternate_mem_allocation = + allow_no_copy_alternate_mem_allocation; + request.earliest_prefetch_time = earliest_prefetch_time; request.preferred_offset = preferred_offset; request.use = use; request.allocation_value = &allocation_value; @@ -1048,35 +1174,42 @@ void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignment( if (aliased_allocation->memory_space() == MemorySpace::kAlternate) { chunk = aliased_allocation->chunk(); } - const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); - HloValue* value = - &alias_analysis_.dataflow_analysis().GetUniqueValueAt(instruction, index); - int64 instruction_time = instruction_schedule.at(instruction); + AddRequiredAssignment(instruction, index, aliased_allocation->memory_space(), + chunk); +} + +void AlternateMemoryBestFitHeap::AddRequiredAssignment( + const HloValue* value, const HloInstruction* instruction, + MemorySpaceAssignment::MemorySpace memory_space, int64 time, + absl::optional chunk) { // Check for existing required assignment at this time and make sure it is the // same as this if there is one. - auto existing_required_assignment = - RequiredMemoryAssignmentAt(value, instruction_time); + auto existing_required_assignment = RequiredMemoryAssignmentAt(value, time); if (existing_required_assignment) { - CHECK(aliased_allocation->memory_space() == - existing_required_assignment->memory_space); + CHECK(memory_space == existing_required_assignment->memory_space) + << "inst = " << instruction->ToString() << " at " << time; CHECK((!chunk && !existing_required_assignment->chunk) || chunk->offset == existing_required_assignment->chunk->offset); - VLOG(3) << "Not adding aliased required assignment because there is one " - "already: " - << value->ToShortString() << " at " << instruction_time << " at " - << (aliased_allocation->memory_space() == MemorySpace::kDefault - ? "def" - : "alt"); - return; + VLOG(3) << "Not adding required assignment because there is one already: " + << value->ToShortString() << " at " << time << " at " + << (memory_space == MemorySpace::kDefault ? "def" : "alt"); + } else { + VLOG(3) << "Adding required assignment: " << value->ToShortString() + << " at " << time << " at " + << (memory_space == MemorySpace::kDefault ? "def" : "alt"); + required_assignments_[value].push_back({memory_space, time, chunk}); } +} - required_assignments_[value].push_back( - {aliased_allocation->memory_space(), instruction_time, chunk}); - VLOG(3) << "Adding aliased required assignment: " << value->ToShortString() - << " at " << instruction_time << " at " - << (aliased_allocation->memory_space() == MemorySpace::kDefault - ? 
"def" - : "alt"); +void AlternateMemoryBestFitHeap::AddRequiredAssignment( + const HloInstruction* instruction, ShapeIndex index, + MemorySpace memory_space, absl::optional chunk) { + const HloValue* value = + &alias_analysis_.dataflow_analysis().GetUniqueValueAt(instruction, index); + int64 instruction_time = + hlo_live_range_.instruction_schedule().at(instruction); + AddRequiredAssignment(value, instruction, memory_space, instruction_time, + chunk); } void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() { @@ -1174,10 +1307,13 @@ void AlternateMemoryBestFitHeap::UncommitPendingChunks() { interval_tree_.Remove(interval.start, interval.end, chunk); } for (const auto& interval : pending_async_copies_) { - async_copy_interval_tree_.Remove(interval.start_time, interval.end_time, - kDummyChunk); if (interval.destination == MemorySpace::kAlternate) { + prefetch_interval_tree_.Remove(interval.start_time, interval.end_time, + kDummyChunk); async_copy_ordering_.RemoveCopy(interval); + } else { + eviction_interval_tree_.Remove(interval.start_time, interval.end_time, + kDummyChunk); } } pending_chunks_.clear(); @@ -1276,6 +1412,7 @@ bool AlternateMemoryBestFitHeap::FindAllocation( // First try keeping the allocation entirely in the alternate memory. if (required_memory_space_at_start != MemorySpace::kDefault && required_memory_space_at_end != MemorySpace::kDefault && + request.allow_no_copy_alternate_mem_allocation && AllocateInAlternateMemoryNoCopy(request)) { return true; } @@ -1350,6 +1487,7 @@ void AlternateMemoryBestFitHeap::AddAsyncCopy( : "alternate") << " memory between " << start_time << " and " << copy_done_schedule_before_time << " keeping until " << end_time; + CHECK_LT(start_time, copy_done_schedule_before_time); allocations->push_back( absl::make_unique( @@ -1360,27 +1498,37 @@ void AlternateMemoryBestFitHeap::AddAsyncCopy( // the limit at any given time. pending_async_copies_.push_back( {start_time, copy_done_schedule_before_time, memory_space}); - async_copy_interval_tree_.Add(start_time, copy_done_schedule_before_time, - kDummyChunk); if (memory_space == MemorySpaceAssignment::MemorySpace::kAlternate) { + prefetch_interval_tree_.Add(start_time, copy_done_schedule_before_time, + kDummyChunk); async_copy_ordering_.AddCopy(pending_async_copies_.back()); + } else { + eviction_interval_tree_.Add(start_time, copy_done_schedule_before_time, + kDummyChunk); } } bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies( - int64 start_time, int64 end_time) const { - if (options_.max_outstanding_async_copies < 0) { + int64 start_time, int64 end_time, bool is_prefetch) const { + if (options_.max_outstanding_prefetches < 0 && is_prefetch) { + return false; + } + if (options_.max_outstanding_evictions < 0 && !is_prefetch) { return false; } - // Count the asynchronous copies in the interval tree for the given interval. - int64 num_async_copies = - async_copy_interval_tree_.ChunksOverlappingInTime(start_time, end_time) - .size(); - - // Add one because we are checking if adding an additional asynchronous copy - // would violate the limit. - return num_async_copies + 1 > options_.max_outstanding_async_copies; + // Count the prefetches/evictions in the interval tree for the given interval. 
+ if (is_prefetch) { + int64 num_prefetches = + prefetch_interval_tree_.ChunksOverlappingInTime(start_time, end_time) + .size(); + return num_prefetches >= options_.max_outstanding_prefetches; + } else { + int64 num_evictions = + eviction_interval_tree_.ChunksOverlappingInTime(start_time, end_time) + .size(); + return num_evictions >= options_.max_outstanding_evictions; + } } bool AlternateMemoryBestFitHeap::ViolatesAsyncCopyOrdering( @@ -1512,6 +1660,9 @@ bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) { request.allocation_value->defining_position().shape(), eviction_start_time, request.end_time), eviction_end_time); + // Evictions must complete by the time of this use. + preferred_eviction_end_time = + std::min(preferred_eviction_end_time, request.latest_prefetch_time); BufferInterval eviction_mem_interval; eviction_mem_interval.buffer = request.allocation_value->value(); @@ -1519,8 +1670,7 @@ bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) { // Try to reserve a buffer from the end of the previous allocation to the // preferred eviction end time. eviction_mem_interval.start = eviction_end_time + 1; - eviction_mem_interval.end = - std::min(preferred_eviction_end_time, global_max_time_); + eviction_mem_interval.end = preferred_eviction_end_time; int64 preferred_offset = prev_allocation->chunk().offset; VLOG(3) << "Eviction (" << eviction_start_time << ", " << eviction_end_time << ") preferred end time = " << eviction_mem_interval.end; @@ -1542,7 +1692,8 @@ bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) { bool eviction_interval_too_short = (eviction_start_time == eviction_end_time); bool eviction_violates_outstanding_copies = ViolatesMaximumOutstandingAsyncCopies(eviction_start_time, - eviction_end_time); + eviction_end_time, + /*is_prefetch=*/false); // See if this interval would violate the asynchronous copy limit. 
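Two small conditions from the eviction path above, pulled out for clarity: the heuristic eviction end time is clamped so the eviction finishes no later than the latest prefetch time of this use, and an eviction whose start and end collapse to the same slot is treated as too short. A compressed sketch of just those two checks (variable names are illustrative; the real code also consults the outstanding-copy limit):

#include <algorithm>
#include <cstdint>

// Clamps the preferred eviction end time and reports whether the window is
// non-degenerate. The caller would still need to verify the async-copy limit.
bool EvictionWindowIsUsable(int64_t eviction_start_time,
                            int64_t eviction_end_time,
                            int64_t latest_prefetch_time,
                            int64_t* preferred_eviction_end_time) {
  // Evictions must complete by the time the evicted value is needed again.
  *preferred_eviction_end_time =
      std::min(*preferred_eviction_end_time, latest_prefetch_time);
  bool eviction_interval_too_short =
      (eviction_start_time == eviction_end_time);
  return !eviction_interval_too_short;
}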
if (!eviction_interval_too_short && !eviction_violates_outstanding_copies) { @@ -1563,7 +1714,8 @@ bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) { bool eviction_scheduled = false; for (int64 time = eviction_start_time; time < eviction_end_time; ++time) { VLOG(4) << "Try evicting (" << time << ", " << time + 1 << ")"; - if (!ViolatesMaximumOutstandingAsyncCopies(time, time + 1)) { + if (!ViolatesMaximumOutstandingAsyncCopies(time, time + 1, + /*is_prefetch=*/false)) { VLOG(3) << "Eviction successful."; AddAsyncCopy(*prev_allocation, MemorySpace::kDefault, /*chunk=*/absl::nullopt, time, time + 1, time + 1, @@ -1605,9 +1757,14 @@ bool AlternateMemoryBestFitHeap::Prefetch( // ^ ^ // Copy Copy // Start Done - options_.prefetch_interval_picker->Begin( - request.use, prev_allocation_in_default_mem.earliest_available_time(), - request.latest_prefetch_time); + int64 earliest_prefetch_time = + prev_allocation_in_default_mem.earliest_available_time(); + if (request.earliest_prefetch_time) { + earliest_prefetch_time = + std::max(earliest_prefetch_time, *request.earliest_prefetch_time); + } + options_.prefetch_interval_picker->Begin(request.use, earliest_prefetch_time, + request.latest_prefetch_time); VLOG(3) << "Trying prefetch picker = " << options_.prefetch_interval_picker->ToDebugString(); @@ -1618,12 +1775,14 @@ bool AlternateMemoryBestFitHeap::Prefetch( alternate_mem_interval.size = request.size; while (!options_.prefetch_interval_picker->Done()) { alternate_mem_interval.start = options_.prefetch_interval_picker->Next(); + CHECK_LT(alternate_mem_interval.start, request.latest_prefetch_time); VLOG(4) << "Trying alternate memory allocation (" << alternate_mem_interval.start << ", " << request.end_time << ")"; // If this additional asynchronous copy would violate the limit, try a // different interval. 
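The prefetch loop that follows walks candidate copy-start times produced by the interval picker and simply skips any candidate that would exceed the outstanding-prefetch limit, rather than giving up on the prefetch entirely. A standalone sketch of that retry pattern, modeling the picker as a list of candidate start times (illustrative names; the real code CHECK-fails on a too-late start instead of skipping it):

#include <cstdint>
#include <functional>
#include <optional>
#include <vector>

// Returns the first candidate copy-start time that passes both checks, or
// nullopt if every candidate fails.
std::optional<int64_t> PickPrefetchStart(
    const std::vector<int64_t>& candidate_start_times,
    int64_t latest_prefetch_time,
    const std::function<bool(int64_t, int64_t)>& violates_copy_limit,
    const std::function<bool(int64_t, int64_t)>& chunk_available) {
  for (int64_t start : candidate_start_times) {
    if (start >= latest_prefetch_time) {
      continue;  // CopyDone must still fit after CopyStart.
    }
    if (violates_copy_limit(start, latest_prefetch_time)) {
      continue;  // Too many outstanding prefetches; try a later start.
    }
    if (chunk_available(start, latest_prefetch_time)) {
      return start;  // Found a start time with both copy budget and space.
    }
  }
  return std::nullopt;
}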
if (ViolatesMaximumOutstandingAsyncCopies(alternate_mem_interval.start, - request.latest_prefetch_time)) { + request.latest_prefetch_time, + /*is_prefetch=*/true)) { VLOG(4) << "This would violate the outstanding async copy limit."; continue; } @@ -1693,28 +1852,48 @@ AlternateMemoryBestFitHeap::FindBestChunkCandidate( return absl::nullopt; } -/*static*/ int64 MemorySpaceAssignment::CountMaximumOutstandingAsyncCopies( - const HloModule& module) { - int64 max_copies = 0; +StatusOr +MemorySpaceAssignment::CalculateAsyncCopyStats() const { + AsyncCopyStats stats; + stats.max_outstanding_async_copies = 0; + stats.num_prefetches = 0; + stats.prefetch_bytes = 0; + stats.num_evictions = 0; + stats.eviction_bytes = 0; int64 current_copies = 0; - for (HloInstruction* instruction : - module.schedule().sequence(module.entry_computation()).instructions()) { - if (instruction->opcode() == HloOpcode::kCopyStart) { - current_copies++; - } else if (instruction->opcode() == HloOpcode::kCopyDone) { - current_copies--; + TF_ASSIGN_OR_RETURN(std::unique_ptr dataflow_analysis, + HloDataflowAnalysis::Run(*module_)); + for (const HloComputation* computation : + module_->MakeNonfusionComputations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kCopyStart) { + current_copies++; + } else if (instruction->opcode() == HloOpcode::kCopyDone) { + current_copies--; + int64 size = + options_.size_fn(dataflow_analysis->GetUniqueValueAt(instruction)); + if (instruction->shape().layout().memory_space() == + options_.alternate_memory_space) { + ++stats.num_prefetches; + stats.prefetch_bytes += size; + } else { + ++stats.num_evictions; + stats.eviction_bytes += size; + } + } + stats.max_outstanding_async_copies = + std::max(stats.max_outstanding_async_copies, current_copies); } - max_copies = std::max(max_copies, current_copies); } - return max_copies; + return stats; } /*static*/ MemorySpaceAssignment::BufferIntervalCompare MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare( const MemorySpaceAssignmentCostAnalysis& cost_analysis) { return [&](const BufferInterval& x, const BufferInterval& y) { - float x_memory_boundedness = GetMemoryBoundedness(cost_analysis, x); - float y_memory_boundedness = GetMemoryBoundedness(cost_analysis, y); + float x_memory_boundedness = cost_analysis.GetMemoryBoundedness(x); + float y_memory_boundedness = cost_analysis.GetMemoryBoundedness(y); if (x_memory_boundedness != y_memory_boundedness) { return x_memory_boundedness > y_memory_boundedness; } @@ -1748,6 +1927,8 @@ bool LooksLikeAnActivation(const HloInstruction* inst) { } } break; + case HloOpcode::kBitcast: + return LooksLikeAnActivation(user); default: return true; } @@ -1764,10 +1945,20 @@ bool IsCrossProgramPrefetchCandidate( !value.uses().empty() && options.size_fn(value) <= options.max_size_in_bytes && absl::c_all_of(value.uses(), [&](const HloUse& use) { - const HloInstruction* gte = + const HloInstruction* inst = use.instruction->operand(use.operand_number); - return gte->opcode() == HloOpcode::kGetTupleElement && - !LooksLikeAnActivation(gte); + + // Skip the LooksLikeAnActivation test since we're testing the + // parent GTE and its children below. 
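The candidate test that follows now also accepts a use whose operand is a bitcast of a get-tuple-element of the entry parameter, where previously only a direct get-tuple-element was accepted. A toy standalone version of that structural check over a simplified instruction node (not the HloInstruction API; the real check additionally rejects operands that look like activations):

#include <vector>

enum class Op { kParameter, kGetTupleElement, kBitcast, kDot, kOther };

struct Node {
  Op op;
  std::vector<const Node*> operands;
};

// Returns true if `operand` looks like a cross-program-prefetchable input:
// either gte(parameter) or bitcast(gte(parameter)).
bool IsPrefetchableOperand(const Node& operand) {
  if (operand.op == Op::kBitcast && operand.operands.size() == 1 &&
      operand.operands[0]->op == Op::kGetTupleElement &&
      operand.operands[0]->operands.size() == 1 &&
      operand.operands[0]->operands[0]->op == Op::kParameter) {
    return true;
  }
  return operand.op == Op::kGetTupleElement;
}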
+ if (inst->opcode() == HloOpcode::kBitcast && + inst->operand(0)->opcode() == HloOpcode::kGetTupleElement && + inst->operand(0)->operand(0)->opcode() == + HloOpcode::kParameter) { + return true; + } + + return inst->opcode() == HloOpcode::kGetTupleElement && + !LooksLikeAnActivation(inst); }); } @@ -1838,8 +2029,13 @@ MemorySpaceAssignment::RunMemorySpaceAssignment( VLOG(3) << "Module after memory space assignment: "; XLA_VLOG_LINES(3, module_->ToString()); TF_CHECK_OK(module_->schedule().Verify()); + TF_ASSIGN_OR_RETURN(AsyncCopyStats stats, CalculateAsyncCopyStats()); VLOG(1) << "Maximum number of outstanding async copies: " - << CountMaximumOutstandingAsyncCopies(*module_); + << stats.max_outstanding_async_copies; + VLOG(1) << "Number of prefetches: " << stats.num_prefetches + << ", in bytes: " << stats.prefetch_bytes; + VLOG(1) << "Number of evictions: " << stats.num_evictions + << ", in bytes: " << stats.eviction_bytes; TF_RETURN_IF_ERROR(VerifyAndExportHeapSimulatorTrace()); @@ -2398,6 +2594,34 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() { std::tuple> events; + auto add_allocation_and_verify = [&](int64 start_time, int64 end_time, + const Chunk& chunk, + const HloValue* value) { + events[std::make_tuple(start_time, /*is_free=*/false, value->id())] = + std::make_tuple(value, chunk, HeapSimulatorTrace::Event::ALLOC); + events[std::make_tuple(end_time, /*is_free=*/true, value->id())] = + std::make_tuple(value, chunk, HeapSimulatorTrace::Event::FREE); + + // Get the chunks overlapping in time and search if they overlap in space + // as well. + // TODO(berkin): For now checking against end_time - 1 (exclusive), but we + // really should check against end_time (inclusive) for cases where the + // operand can't share buffer with user (see + // HloDataflowAnalysis::CanShareOperandBufferWithUser). + for (const Chunk& overlapping_chunk : + interval_tree.ChunksOverlappingInTime(start_time, end_time - 1)) { + if (chunk.OverlapsWith(overlapping_chunk)) { + return InternalError( + ("Value %s (%d, %d) off: %d size: %d overlaps with another chunk" + " off: %d size: %d"), + value->ToShortString(), start_time, end_time, chunk.offset, + chunk.size, overlapping_chunk.offset, overlapping_chunk.size); + } + } + interval_tree.Add(start_time, end_time - 1, chunk); + return Status::OK(); + }; + // Go through all instructions in the module to ensure CopyStart/CopyDone // instructions copy between alternate memory and default memory. for (const HloComputation* computation : @@ -2433,34 +2657,73 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() { for (const HloValue* value : buffer.values()) { const HloLiveRange::TimeBound& time_bound = hlo_live_range->buffer_live_ranges().at(value); - events[std::make_tuple(time_bound.start, /*is_free=*/false, - value->id())] = - std::make_tuple(value, chunk, HeapSimulatorTrace::Event::ALLOC); - events[std::make_tuple(time_bound.end, /*is_free=*/true, value->id())] = - std::make_tuple(value, chunk, HeapSimulatorTrace::Event::FREE); - - VLOG(3) << " buffer: " << buffer.ToString() - << " value: " << value->ToShortString() << ": (" - << time_bound.start << ", " << time_bound.end - << ") off: " << chunk.offset << ", size: " << chunk.size; - // Get the chunks overlapping in time and search if they overlap in space - // as well. 
- // TODO(berkin): For now checking against end_time - 1 (exclusive), but we - // really should check against end_time (inclusive) for cases where the - // operand can't share buffer with user (see - // HloDataflowAnalysis::CanShareOperandBufferWithUser). - for (const Chunk& overlapping_chunk : - interval_tree.ChunksOverlappingInTime(time_bound.start, - time_bound.end - 1)) { - if (chunk.OverlapsWith(overlapping_chunk)) { - return InternalError( - ("Buffer %s (%d, %d) off: %d size: %d overlaps with another chunk" - " off: %d size: %d"), - buffer.ToString(), time_bound.start, time_bound.end, chunk.offset, - chunk.size, overlapping_chunk.offset, overlapping_chunk.size); + const HloInstruction* last_use_instruction = nullptr; + int64 last_use_time = time_bound.start; + for (const HloUse& use : value->uses()) { + int64 use_time = + hlo_live_range->instruction_schedule().at(use.instruction); + if (use_time > last_use_time) { + last_use_time = use_time; + last_use_instruction = use.instruction; } } - interval_tree.Add(time_bound.start, time_bound.end - 1, chunk); + + if (last_use_instruction && + last_use_instruction->opcode() == HloOpcode::kConditional) { + // Special case when verifying conditional: we internally split the use + // of alternate memory in conditionals, so fish them out from the + // conditionals. + VLOG(3) << " Splitting conditional buffer: " << buffer.ToString() + << " value: " << value->ToShortString() << ": (" + << time_bound.start << ", " << time_bound.end + << ") off: " << chunk.offset << ", size: " << chunk.size; + int64 earliest_computation_start_time = time_bound.end; + for (const HloComputation* called_computation : + last_use_instruction->called_computations()) { + earliest_computation_start_time = + std::min(earliest_computation_start_time, + hlo_live_range->computation_span_times() + .at(called_computation) + .start); + int64 parameter_time = -1; + int64 last_use_time = -1; + for (const HloPosition& position : value->positions()) { + if (position.instruction->opcode() == HloOpcode::kParameter && + position.instruction->parent() == called_computation) { + parameter_time = hlo_live_range->instruction_schedule().at( + position.instruction); + break; + } + } + for (const HloUse& use : value->uses()) { + if (use.instruction->parent() == called_computation) { + last_use_time = std::max( + last_use_time, + hlo_live_range->instruction_schedule().at(use.instruction)); + } + } + if (last_use_time != -1) { + CHECK_NE(parameter_time, -1); + VLOG(3) << " computation: " << called_computation->name() << ": (" + << parameter_time << ", " << last_use_time << ")"; + TF_RETURN_IF_ERROR(add_allocation_and_verify( + parameter_time, last_use_time, chunk, value)); + } + } + VLOG(3) << " from beginning until first computation: (" + << time_bound.start << ", " + << (earliest_computation_start_time - 1) << ")"; + TF_RETURN_IF_ERROR(add_allocation_and_verify( + time_bound.start, earliest_computation_start_time - 1, chunk, + value)); + } else { + VLOG(3) << " buffer: " << buffer.ToString() + << " value: " << value->ToShortString() << ": (" + << time_bound.start << ", " << time_bound.end + << ") off: " << chunk.offset << ", size: " << chunk.size; + TF_RETURN_IF_ERROR(add_allocation_and_verify( + time_bound.start, time_bound.end, chunk, value)); + } } } diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h index eb16db90600..3f59abfd28e 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.h +++ 
b/tensorflow/compiler/xla/service/memory_space_assignment.h @@ -82,16 +82,31 @@ class MemorySpaceAssignmentCostAnalysis { const HloCostAnalysis& cost_analysis, float async_copy_bandwidth_bytes_per_second, float alternate_mem_bandwidth_bytes_per_second, - const HloLiveRange& hlo_live_range) + const HloLiveRange& hlo_live_range, const CallGraph& call_graph) : cost_analysis_(cost_analysis), async_copy_bandwidth_bytes_per_second_( async_copy_bandwidth_bytes_per_second), alternate_mem_bandwidth_bytes_per_second_( alternate_mem_bandwidth_bytes_per_second), - hlo_live_range_(hlo_live_range) {} + hlo_live_range_(hlo_live_range), + call_graph_(call_graph) {} const HloCostAnalysis& cost_analysis() const { return cost_analysis_; } + // Returns a heuristic value that captures how much putting this tensor to the + // alternate memory would help if the op is memory bound, or otherwise how far + // off is the op to memory boundedness. The larger this number, the higher + // priority it will be placed in the alternate memory. + float GetAlternateMemoryBenefit( + const HloInstruction& instruction, + float elapsed_time_due_to_alternate_mem) const; + + // Returns a heuristic value of memory boundedness for the given + // BufferInterval. The larger this number, the higher priority it will be + // placed in the alternate memory. + float GetMemoryBoundedness( + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const; + // Returns the elapsed time in seconds due to compute only. float GetInstructionElapsedDueToCompute( const HloInstruction& instruction) const; @@ -127,6 +142,10 @@ class MemorySpaceAssignmentCostAnalysis { int64 GetScheduleEndTime() const; + // Returns the number of nested while loop levels this instruction resides in. + // 0 means it is not in a while loop. + int CalculateWhileLoopNestLevel(const HloInstruction* instruction) const; + const HloLiveRange& hlo_live_range() const { return hlo_live_range_; } private: @@ -134,6 +153,7 @@ class MemorySpaceAssignmentCostAnalysis { float async_copy_bandwidth_bytes_per_second_; float alternate_mem_bandwidth_bytes_per_second_; const HloLiveRange& hlo_live_range_; + const CallGraph& call_graph_; }; // Abstract base class that memory space assignment uses to pick prefetch @@ -262,10 +282,10 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker { // corresponds to the instruction schedule. float GetLogicalIntervalElapsed(int64 start_time, int64 end_time) const; - // For performance reasons, we calculate the prefix sum of the elapsed time so - // that it's efficient to find the elapsed time in seconds in any logical - // interval. - std::vector elapsed_time_cumsum_; + // For each instruction in the flattened schedule, maintain their elapsed time + // and while nesting level. + std::vector elapsed_time_; + std::vector while_nest_level_; const MemorySpaceAssignmentCostAnalysis& cost_analysis_; float min_async_copy_to_overlap_ratio_; @@ -323,9 +343,10 @@ class MemorySpaceAssignment { // the opcode) to be placed on the alternate memory. IsAllowedInAlternateMemoryFunction is_allowed_in_alternate_mem_fn; - // Specifies the upper bound for number of outstanding asynchronous copies, - // -1 for unlimited. - int64 max_outstanding_async_copies = -1; + // Specifies the upper bound for number of outstanding prefetches and + // evictions, -1 for unlimited. 
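The two fields declared next replace the single max_outstanding_async_copies option; both default to -1, meaning unlimited. A hedged usage sketch of how a caller might cap them independently (assuming the usual in-tree setup; the other Options fields such as size_fn and the prefetch interval picker still have to be filled in as elsewhere in this change):

#include "tensorflow/compiler/xla/service/memory_space_assignment.h"

// Sketch only: limit in-flight prefetches and evictions separately.
xla::MemorySpaceAssignment::Options MakeThrottledOptions() {
  xla::MemorySpaceAssignment::Options options;
  options.max_outstanding_prefetches = 2;  // At most 2 concurrent prefetches.
  options.max_outstanding_evictions = 1;   // At most 1 concurrent eviction.
  return options;
}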
+ int64 max_outstanding_prefetches = -1; + int64 max_outstanding_evictions = -1; // If true, tries allocating buffers across (e.g., before and inside a while // loop body) sequential calls (kWhile, kCall, and kConditional). @@ -604,6 +625,15 @@ class MemorySpaceAssignment { AllocationSequence allocation_sequence_; }; + // Statistics of asynchronous copies. + struct AsyncCopyStats { + int64 max_outstanding_async_copies; + int64 num_prefetches; + int64 prefetch_bytes; + int64 num_evictions; + int64 eviction_bytes; + }; + virtual ~MemorySpaceAssignment() = default; // Runs the MemorySpaceAssignment pass. @@ -611,9 +641,8 @@ class MemorySpaceAssignment { HloModule* module, const HloLiveRange& hlo_live_range, const HloAliasAnalysis& alias_analysis, const Options& options); - // Returns the maximum number of outstanding asynchronous copies in the - // module. - static int64 CountMaximumOutstandingAsyncCopies(const HloModule& module); + // Calculates asynchronous copy statistics. + StatusOr CalculateAsyncCopyStats() const; static BufferIntervalCompare GetMemoryBoundednessBufferIntervalCompare( const MemorySpaceAssignmentCostAnalysis& cost_analysis); @@ -808,11 +837,16 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { // use_times is a sorted sequence of the times of all uses. // latest_prefetch_time is the latest time we can schedule the CopyDone for a // prefetch. + // If allow_no_copy_alternate_mem_allocation is false, an eviction is forced. + // If earliest_prefetch_time is set, prefetches cannot start before this + // value. struct AllocationRequest { int64 start_time; int64 end_time; int64 latest_prefetch_time; int64 size; + bool allow_no_copy_alternate_mem_allocation; + absl::optional earliest_prefetch_time; absl::optional preferred_offset; HloUse use; MemorySpaceAssignment::AllocationValue* allocation_value; @@ -833,7 +867,8 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { bool IsIntervalAllowedInAlternateMemory(const BufferInterval& interval) const; // Returns true if the use is allowed in the alternate memory. - bool IsUseAllowedInAlternateMemory(const HloUse& use) const; + bool IsUseAllowedInAlternateMemory(const AllocationValue& value, + const HloUse& use) const; // Given an HloValue, creates AllocationValue objects and corresponding // AllocationSequences and appends them into allocation_sequence_list_. @@ -887,6 +922,16 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { const HloInstruction* instruction, ShapeIndex index, const MemorySpaceAssignment::Allocation* aliased_allocation); + // This sets a required assignment. CHECK fails if there is a conflicting + // required assignment at the same time. + void AddRequiredAssignment(const HloValue* value, + const HloInstruction* instruction, + MemorySpace memory_space, int64 time, + absl::optional chunk = absl::nullopt); + void AddRequiredAssignment(const HloInstruction* instruction, + ShapeIndex index, MemorySpace memory_space, + absl::optional chunk = absl::nullopt); + // Adds input and outputs as required assignments. void AddInputAndOutputRequiredAssignments(); @@ -909,8 +954,8 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { // Returns true if the addition of an asynchronous copy in the given time // interval would violate the maximum number of asynchronous copies. 
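The AsyncCopyStats struct introduced above replaces the single max-outstanding-copies count with per-kind counters and byte totals, which RunMemorySpaceAssignment then logs. A standalone sketch of how such a summary might be formatted (the struct below mirrors the fields declared in this header but is a local stand-in, not the XLA type):

#include <cstdint>
#include <sstream>
#include <string>

struct AsyncCopyStatsLike {
  int64_t max_outstanding_async_copies = 0;
  int64_t num_prefetches = 0;
  int64_t prefetch_bytes = 0;
  int64_t num_evictions = 0;
  int64_t eviction_bytes = 0;
};

std::string FormatStats(const AsyncCopyStatsLike& stats) {
  std::ostringstream out;
  out << "Maximum number of outstanding async copies: "
      << stats.max_outstanding_async_copies << "\n"
      << "Number of prefetches: " << stats.num_prefetches
      << ", in bytes: " << stats.prefetch_bytes << "\n"
      << "Number of evictions: " << stats.num_evictions
      << ", in bytes: " << stats.eviction_bytes << "\n";
  return out.str();
}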
- bool ViolatesMaximumOutstandingAsyncCopies(int64 start_time, - int64 end_time) const; + bool ViolatesMaximumOutstandingAsyncCopies(int64 start_time, int64 end_time, + bool is_prefetch) const; // Return true if the asynchronous copy would violate the pipelining order. bool ViolatesAsyncCopyOrdering(int64 start_time, int64 end_time) const; @@ -953,8 +998,9 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { const HloAliasAnalysis& alias_analysis_; const HloLiveRange& hlo_live_range_; // We use a interval tree to keep track of the number of outstanding - // asynchronous copies. - BufferIntervalTree async_copy_interval_tree_; + // prefetches and evictions. + BufferIntervalTree prefetch_interval_tree_; + BufferIntervalTree eviction_interval_tree_; AsynchronousCopyOrdering async_copy_ordering_; std::vector> pending_chunks_; std::vector pending_async_copies_; @@ -964,7 +1010,6 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { required_assignments_; // Number of bytes reserved in alternate memory space. int64 reserved_in_bytes_ = 0; - int64 global_max_time_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc index b2125d318d0..23b311730f8 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc @@ -57,9 +57,10 @@ class MemorySpaceAssignmentTest : public HloTestBase, HloLiveRange::Run(module->schedule(), *alias_analysis, module->entry_computation()) .ValueOrDie(); + std::unique_ptr call_graph = CallGraph::Build(module); MemorySpaceAssignmentCostAnalysis cost_analysis( hlo_cost_analysis, kAsyncCopyBandwidth, kAlternateMemBandwidth, - *hlo_live_range); + *hlo_live_range, *call_graph); CostAnalysisPrefetchIntervalPicker prefetch_interval_picker( CostAnalysisPrefetchIntervalPicker( cost_analysis, /*min_async_copy_to_overlap_ratio=*/0.8, @@ -126,7 +127,8 @@ class MemorySpaceAssignmentTest : public HloTestBase, options.prefetch_interval_picker = prefetch_interval_picker; options.size_fn = size_fn; options.is_allowed_in_alternate_mem_fn = is_allowed_in_alternate_mem; - options.max_outstanding_async_copies = max_outstanding_async_copies; + options.max_outstanding_prefetches = max_outstanding_async_copies; + options.max_outstanding_evictions = max_outstanding_async_copies; options.allocate_across_sequential_calls = GetParam(); options.verify = true; @@ -184,6 +186,47 @@ class MemorySpaceAssignmentTest : public HloTestBase, } } + struct OutstandingAsyncCopies { + int64 max_copies; + int64 max_prefetches; + int64 max_evictions; + }; + + /*static*/ OutstandingAsyncCopies CountMaximumOutstandingAsyncCopies( + const HloModule& module) { + OutstandingAsyncCopies copies{0, 0, 0}; + int64 current_copies = 0; + int64 current_prefetches = 0; + int64 current_evictions = 0; + for (HloInstruction* instruction : module.schedule() + .sequence(module.entry_computation()) + .instructions()) { + if (instruction->opcode() == HloOpcode::kCopyStart) { + current_copies++; + if (ShapeUtil::GetSubshape(instruction->shape(), {0}) + .layout() + .memory_space() == kAlternateMemorySpace) { + current_prefetches++; + } else { + current_evictions++; + } + } else if (instruction->opcode() == HloOpcode::kCopyDone) { + current_copies--; + if (instruction->shape().layout().memory_space() == + kAlternateMemorySpace) { + current_prefetches--; + } else { + current_evictions--; + } + } + 
copies.max_copies = std::max(copies.max_copies, current_copies); + copies.max_prefetches = + std::max(copies.max_prefetches, current_prefetches); + copies.max_evictions = std::max(copies.max_evictions, current_evictions); + } + return copies; + } + std::unique_ptr<HloModule> CreateEvictAndPrefetchModule() { HloComputation::Builder builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); @@ -391,8 +434,8 @@ TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies0) { AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/0); - EXPECT_EQ(MemorySpaceAssignment::CountMaximumOutstandingAsyncCopies(*module), - 0); + EXPECT_LE(CountMaximumOutstandingAsyncCopies(*module).max_prefetches, 0); + EXPECT_LE(CountMaximumOutstandingAsyncCopies(*module).max_evictions, 0); } TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies1) { @@ -400,8 +443,8 @@ TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies1) { AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/1); - EXPECT_EQ(MemorySpaceAssignment::CountMaximumOutstandingAsyncCopies(*module), - 1); + EXPECT_LE(CountMaximumOutstandingAsyncCopies(*module).max_prefetches, 1); + EXPECT_LE(CountMaximumOutstandingAsyncCopies(*module).max_evictions, 1); } TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies2) { @@ -409,8 +452,8 @@ TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies2) { AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/2); - EXPECT_EQ(MemorySpaceAssignment::CountMaximumOutstandingAsyncCopies(*module), - 2); + EXPECT_LE(CountMaximumOutstandingAsyncCopies(*module).max_prefetches, 2); + EXPECT_LE(CountMaximumOutstandingAsyncCopies(*module).max_evictions, 2); } // TODO(berkin): This test is broken with some prefetch timing improvements. @@ -1650,6 +1693,324 @@ TEST_P(MemorySpaceAssignmentTest, ControlPredecessorsBug) { AssignMemorySpace(module.get()); } +TEST_P(MemorySpaceAssignmentTest, ConditionalShouldBeAllocatedInAlternateMem) { + // Checks if simple conditionals get alternate memory allocations. + absl::string_view hlo_string = R"( + HloModule CondAllocation, is_scheduled=true + + true_computation { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + ROOT neg1 = f32[3]{0} negate(gte) + } + + false_computation { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + ROOT neg2 = f32[3]{0} negate(gte) + } + + ENTRY entry { + p0 = f32[3]{0} parameter(0) + p1 = pred[] parameter(1) + copy = f32[3]{0} copy(p0) + tuple = (f32[3]{0}) tuple(copy) + ROOT conditional = f32[3]{0} conditional(p1, tuple, tuple), true_computation=true_computation, false_computation=false_computation + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AssignMemorySpace(module.get()); + + if (GetParam()) { + // Check that copy and gtes got alternate memory allocations.
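The test helper above derives its maxima with a simple running counter over the schedule: each CopyStart increments, each CopyDone decrements, and the peak is recorded along the way. A standalone sketch of that scan over a toy event stream (illustrative only):

#include <algorithm>
#include <cstdint>
#include <vector>

enum class Event { kCopyStart, kCopyDone, kOther };

// Returns the maximum number of copies that are simultaneously in flight.
int64_t MaxOutstandingCopies(const std::vector<Event>& schedule) {
  int64_t current = 0;
  int64_t max_copies = 0;
  for (Event event : schedule) {
    if (event == Event::kCopyStart) {
      ++current;
    } else if (event == Event::kCopyDone) {
      --current;
    }
    max_copies = std::max(max_copies, current);
  }
  return max_copies;
}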
+ auto copy = + module->GetComputationWithName("entry")->GetInstructionWithName("copy"); + EXPECT_EQ(copy->shape().layout().memory_space(), kAlternateMemorySpace); + auto neg1 = module->GetComputationWithName("true_computation") + ->GetInstructionWithName("neg1"); + auto neg1_operand = neg1->operand(0); + EXPECT_EQ(neg1_operand->shape().layout().memory_space(), + kAlternateMemorySpace); + auto neg2 = module->GetComputationWithName("false_computation") + ->GetInstructionWithName("neg2"); + auto neg2_operand = neg2->operand(0); + EXPECT_EQ(neg2_operand->shape().layout().memory_space(), + kAlternateMemorySpace); + } +} + +TEST_P(MemorySpaceAssignmentTest, ConditionalAvoidsUnnecessaryPrefetch) { + // Checks if we avoid unnecessary allocation in alternate memory if the input + // won't be used in the computation for a long time. + absl::string_view hlo_string = R"( + HloModule CondAllocation, is_scheduled=true + + true_computation { + p0 = (f32[3]{0}, f32[3]{0}) parameter(0) + gte0 = f32[3]{0} get-tuple-element(p0), index=0 + neg0 = f32[3]{0} negate(gte0) + neg1 = f32[3]{0} negate(neg0) + neg2 = f32[3]{0} negate(neg1) + neg3 = f32[3]{0} negate(neg2) + neg4 = f32[3]{0} negate(neg3) + neg5 = f32[3]{0} negate(neg4) + neg6 = f32[3]{0} negate(neg5) + neg7 = f32[3]{0} negate(neg6) + neg8 = f32[3]{0} negate(neg7) + neg9 = f32[3]{0} negate(neg8) + gte1 = f32[3]{0} get-tuple-element(p0), index=1 + ROOT add = f32[3]{0} add(neg9, gte1) + } + + false_computation { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + ROOT neg = f32[3]{0} negate(gte) + } + + ENTRY entry { + p0 = f32[3]{0} parameter(0) + p1 = pred[] parameter(1) + copy0 = f32[3]{0} copy(p0) + copy1 = f32[3]{0} copy(p0) + tuple0 = (f32[3]{0}, f32[3]{0}) tuple(copy0, copy1) + tuple1 = (f32[3]{0}) tuple(copy0) + ROOT conditional = f32[3]{0} conditional(p1, tuple0, tuple1), true_computation=true_computation, false_computation=false_computation + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AssignMemorySpace(module.get()); + + if (GetParam()) { + // Check that copy1 doesn't get unnecessarily allocated in alternate mem + // (due to long negate chain in true_computation) but is prefetched before + // add. + auto copy0 = + module->GetComputationWithName("entry")->GetInstructionWithName( + "copy0"); + EXPECT_EQ(copy0->shape().layout().memory_space(), kAlternateMemorySpace); + auto copy1 = + module->GetComputationWithName("entry")->GetInstructionWithName( + "copy1"); + EXPECT_EQ(copy1->shape().layout().memory_space(), kDefaultMemorySpace); + auto add = module->GetComputationWithName("true_computation") + ->GetInstructionWithName("add"); + auto add_operand = add->operand(1); + EXPECT_EQ(add_operand->shape().layout().memory_space(), + kAlternateMemorySpace); + } +} + +TEST_P(MemorySpaceAssignmentTest, ConditionalMultiUse) { + // Make sure there is an evict when there is a conditional use followed by + // another use. 
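The ConditionalAvoidsUnnecessaryPrefetch test above relies on the prefetch interval picker only issuing a copy when the time remaining until the use overlaps reasonably with the copy itself, so copy1 stays in default memory through the long negate chain and is prefetched just before the add. A heavily simplified toy of that timing trade-off (names and the specific ratios are illustrative; the real picker works from cost-analysis elapsed times, and 0.8 merely mirrors the min overlap ratio used by the tests in this file):

// Toy "is it worth prefetching yet?" check.
bool WorthPrefetchingNow(float copy_elapsed_seconds,
                         float seconds_until_use,
                         float min_overlap_ratio = 0.8f,
                         float max_idle_ratio = 10.0f) {
  if (seconds_until_use < min_overlap_ratio * copy_elapsed_seconds) {
    return false;  // Too late: the copy cannot be hidden behind compute.
  }
  if (seconds_until_use > max_idle_ratio * copy_elapsed_seconds) {
    return false;  // Too early: the value would sit idle in alternate memory.
  }
  return true;
}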
+ absl::string_view hlo_string = R"( + HloModule CondAllocation, is_scheduled=true + + true_computation { + p0 = (f32[3]{0}, f32[3]{0}) parameter(0) + gte0 = f32[3]{0} get-tuple-element(p0), index=0 + gte1 = f32[3]{0} get-tuple-element(p0), index=1 + add0 = f32[3]{0} add(gte0, gte1) + neg0 = f32[3]{0} negate(add0) + neg1 = f32[3]{0} negate(neg0) + neg2 = f32[3]{0} negate(neg1) + neg3 = f32[3]{0} negate(neg2) + neg4 = f32[3]{0} negate(neg3) + neg5 = f32[3]{0} negate(neg4) + neg6 = f32[3]{0} negate(neg5) + neg7 = f32[3]{0} negate(neg6) + neg8 = f32[3]{0} negate(neg7) + ROOT neg9 = f32[3]{0} negate(neg8) + } + + false_computation { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + ROOT neg = f32[3]{0} negate(gte) + } + + ENTRY entry { + p0 = f32[3]{0} parameter(0) + p1 = pred[] parameter(1) + copy0 = f32[3]{0} copy(p0) + copy1 = f32[3]{0} copy(p0) + tuple0 = (f32[3]{0}, f32[3]{0}) tuple(copy0, copy1) + tuple1 = (f32[3]{0}) tuple(copy0) + conditional = f32[3]{0} conditional(p1, tuple0, tuple1), true_computation=true_computation, false_computation=false_computation + ROOT add1 = f32[3]{0} add(copy1, conditional) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AssignMemorySpace(module.get()); + + if (GetParam()) { + // Make sure the copy1->add edge is in alternate memory. Before conditional, + // this should be evicted to default memory and neg uses the input from + // default memory. + auto copy1 = + module->GetComputationWithName("entry")->GetInstructionWithName( + "copy1"); + EXPECT_EQ(copy1->shape().layout().memory_space(), kAlternateMemorySpace); + auto add0 = module->GetComputationWithName("true_computation") + ->GetInstructionWithName("add0"); + auto add0_operand = add0->operand(1); + EXPECT_EQ(add0_operand->shape().layout().memory_space(), + kAlternateMemorySpace); + auto add1 = + module->GetComputationWithName("entry")->GetInstructionWithName("add1"); + auto add1_operand = add1->operand(0); + EXPECT_EQ(add1_operand->shape().layout().memory_space(), + kDefaultMemorySpace); + EXPECT_EQ(add1_operand->opcode(), HloOpcode::kCopyDone); + } +} + +TEST_P(MemorySpaceAssignmentTest, ConditionalMultiUseInWhile) { + absl::string_view hlo_string = R"( + HloModule CondAllocation, is_scheduled=true + + true_computation { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + ROOT neg1 = f32[3]{0} negate(gte) + } + + false_computation { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + ROOT neg2 = f32[3]{0} negate(gte) + } + + while_cond { + p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0) + ROOT gte = pred[] get-tuple-element(p0), index=2 + } + + while_body { + p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0) + gte0 = f32[3]{0} get-tuple-element(p0), index=0 + gte1 = f32[3]{0} get-tuple-element(p0), index=1 + gte2 = pred[] get-tuple-element(p0), index=2 + cond_tuple = (f32[3]{0}) tuple(gte0) + conditional = f32[3]{0} conditional(gte2, cond_tuple, cond_tuple), true_computation=true_computation, false_computation=false_computation + add = f32[3]{0} add(conditional, gte1) + neg0 = f32[3]{0} negate(add) + neg1 = f32[3]{0} negate(neg0) + ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, neg1, gte2) + } + + ENTRY entry { + p0 = f32[3]{0} parameter(0) + p1 = pred[] parameter(1) + copy0 = f32[3]{0} copy(p0) + copy1 = f32[3]{0} copy(p0) + tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy0, copy1, p1) + while = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple), 
condition=while_cond, body=while_body + ROOT gte = f32[3]{0} get-tuple-element(while), index=1 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AssignMemorySpace(module.get()); + + if (GetParam()) { + // Make sure copy1/while{0}/cond_tuple{0} gets alternate memory allocation. + // This will force an eviction and a prefetch for while body root. + auto copy0 = + module->GetComputationWithName("entry")->GetInstructionWithName( + "copy0"); + EXPECT_EQ(copy0->shape().layout().memory_space(), kAlternateMemorySpace); + auto conditional = module->GetComputationWithName("while_body") + ->GetInstructionWithName("conditional"); + auto conditional_operand = conditional->operand(1); + EXPECT_EQ(ShapeUtil::GetSubshape(conditional_operand->shape(), {0}) + .layout() + .memory_space(), + kAlternateMemorySpace); + auto while_root = + module->GetComputationWithName("while_body")->root_instruction(); + auto while_root_operand = while_root->operand(0); + EXPECT_THAT( + while_root_operand, + op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, + op::AsyncCopy(kDefaultMemorySpace, kAlternateMemorySpace, + op::GetTupleElement(op::Parameter(0))))); + } +} + +TEST_P(MemorySpaceAssignmentTest, NestedConditional) { + absl::string_view hlo_string = R"( + HloModule CondAllocation, is_scheduled=true + + true_computation2 { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + ROOT neg1 = f32[3]{0} negate(gte) + } + + false_computation2 { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + ROOT neg2 = f32[3]{0} negate(gte) + } + + true_computation1 { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + slice = f32[1]{0} slice(gte), slice={[0:1]} + bitcast = f32[] bitcast(slice) + constant = f32[] constant(0.0) + compare = pred[] compare(bitcast, constant), direction=GT + ROOT conditional = f32[3]{0} conditional(compare, p0, p0), true_computation=true_computation2, false_computation=false_computation2 + } + + false_computation1 { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + ROOT neg3 = f32[3]{0} negate(gte) + } + + + ENTRY entry { + p0 = f32[3]{0} parameter(0) + p1 = pred[] parameter(1) + copy = f32[3]{0} copy(p0) + tuple = (f32[3]{0}) tuple(copy) + ROOT conditional = f32[3]{0} conditional(p1, tuple, tuple), true_computation=true_computation1, false_computation=false_computation1 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AssignMemorySpace(module.get()); + + if (GetParam()) { + // Make sure alternate memory allocation gets propagated into both levels of + // conditional. 
+ auto copy = + module->GetComputationWithName("entry")->GetInstructionWithName("copy"); + EXPECT_EQ(copy->shape().layout().memory_space(), kAlternateMemorySpace); + auto neg1_operand = module->GetComputationWithName("true_computation2") + ->GetInstructionWithName("neg1") + ->operand(0); + auto neg2_operand = module->GetComputationWithName("false_computation2") + ->GetInstructionWithName("neg2") + ->operand(0); + auto neg3_operand = module->GetComputationWithName("false_computation1") + ->GetInstructionWithName("neg3") + ->operand(0); + EXPECT_EQ(neg1_operand->shape().layout().memory_space(), + kAlternateMemorySpace); + EXPECT_EQ(neg2_operand->shape().layout().memory_space(), + kAlternateMemorySpace); + EXPECT_EQ(neg3_operand->shape().layout().memory_space(), + kAlternateMemorySpace); + } +} + TEST_P(MemorySpaceAssignmentTest, RequestIdentifierShouldNotBeAllocatedInAlternateMem) { // Ensure that request identifier returned by Send/Recv HLOs are not allocated @@ -2136,7 +2497,8 @@ TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule3) { AssignMemorySpace(module.get(), -1, 5); } -TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule4) { +// TODO(berkin): This might be an incorrect input graph, investigate. +TEST_P(MemorySpaceAssignmentTest, DISABLED_NonEntryComputationSchedule4) { auto module = CreateNewVerifiedModule(); Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3}); Shape shape2 = ShapeUtil::MakeShape(xla::F32, {3, 3}); @@ -3428,6 +3790,52 @@ TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchTest) { } } +TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchBitcastTest) { + HloComputation::Builder builder(TestName()); + + constexpr int kBatch = 8; + constexpr int kFeature = 8; + constexpr int kOutput = 2; + + auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature}); + auto rhs_shape = ShapeUtil::MakeShape(F32, {kOutput, kFeature}); + auto bitcast_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput}); + auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput}); + auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + + auto lhs = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(lhs_shape, param, 0)); + auto rhs = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(rhs_shape, param, 1)); + + auto bitcast = + builder.AddInstruction(HloInstruction::CreateBitcast(bitcast_shape, rhs)); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto dot = builder.AddInstruction(HloInstruction::CreateDot( + result_shape, lhs, bitcast, dot_dnums, DefaultPrecisionConfig(2))); + + auto module = CreateNewVerifiedModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + schedule.set_sequence(computation, {param, lhs, rhs, bitcast, dot}); + TF_CHECK_OK(module->set_schedule(schedule)); + + AssignMemorySpace(module.get()); + + auto cross_program_prefetches = module->CrossProgramPrefetches(); + EXPECT_EQ(cross_program_prefetches.size(), 1); + if (!cross_program_prefetches.empty()) { + EXPECT_EQ(cross_program_prefetches[0].first, 0); + EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({1})); + } +} + TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchNestedTupleTest) { HloComputation::Builder builder(TestName()); diff --git 
a/tensorflow/compiler/xla/service/memory_space_propagation.cc b/tensorflow/compiler/xla/service/memory_space_propagation.cc new file mode 100644 index 00000000000..80eb4017477 --- /dev/null +++ b/tensorflow/compiler/xla/service/memory_space_propagation.cc @@ -0,0 +1,67 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/memory_space_propagation.h" + +namespace xla { + +StatusOr MemorySpacePropagation::Run(HloModule* module) { + bool modified = false; + TF_ASSIGN_OR_RETURN(auto dataflow_analysis, + HloDataflowAnalysis::Run(*module)); + dataflow_analysis_ = std::move(dataflow_analysis); + + for (HloComputation* computation : module->MakeNonfusionComputations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kFusion) { + // Propagate the operand subshapes. + for (int operand_idx = 0; operand_idx < instruction->operand_count(); + ++operand_idx) { + modified |= + PropagateSubshapes(instruction->operand(operand_idx)->shape(), + instruction->fused_parameter(operand_idx)); + } + + // Propagate output subshapes. + modified |= PropagateSubshapes(instruction->shape(), + instruction->fused_expression_root()); + } + } + } + return modified; +} + +bool MemorySpacePropagation::PropagateSubshapes( + const Shape& caller_shape, const HloInstruction* callee_instruction) const { + bool modified = false; + for (const ShapeUtil::IndexedShape& indexed_shape : + ShapeUtil::GetLeafShapes(caller_shape)) { + int64 memory_space = indexed_shape.shape.layout().memory_space(); + const HloValue& value = dataflow_analysis_->GetUniqueValueAt( + callee_instruction, indexed_shape.index); + + for (const HloPosition& position : value.positions()) { + Shape* shape = ShapeUtil::GetMutableSubshape( + position.instruction->mutable_shape(), position.index); + if (shape->layout().memory_space() != memory_space) { + shape->mutable_layout()->set_memory_space(memory_space); + modified = true; + } + } + } + return modified; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/memory_space_propagation.h b/tensorflow/compiler/xla/service/memory_space_propagation.h new file mode 100644 index 00000000000..65a1dfd14a6 --- /dev/null +++ b/tensorflow/compiler/xla/service/memory_space_propagation.h @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_PROPAGATION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_PROPAGATION_H_ + +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// This is a legalization pass that propagates the memory space in the layout to +// the fusion computations. +class MemorySpacePropagation : public HloModulePass { + public: + ~MemorySpacePropagation() override = default; + absl::string_view name() const override { return "memory-space-propagation"; } + StatusOr Run(HloModule* module) override; + + private: + // Given the caller shape (operand or output) and its corresponding + // insturction in the fused computation (parameter or root), propagates the + // memory space to all the subshapes in the callee side. Returns true if the + // module is modified. + bool PropagateSubshapes(const Shape& caller_shape, + const HloInstruction* callee_instruction) const; + + std::unique_ptr dataflow_analysis_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_PROPAGATION_H_ diff --git a/tensorflow/compiler/xla/service/memory_space_propagation_test.cc b/tensorflow/compiler/xla/service/memory_space_propagation_test.cc new file mode 100644 index 00000000000..8d74958f6aa --- /dev/null +++ b/tensorflow/compiler/xla/service/memory_space_propagation_test.cc @@ -0,0 +1,203 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/memory_space_propagation.h" + +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class MemorySpacePropagationTest : public HloTestBase { + public: + MemorySpacePropagationTest() + : HloTestBase(), + verifier_(/*layout_sensitive=*/false, /*allow_mixed_precision*/ false) { + } + + Status Verify(HloModule* module) { return verifier_.Run(module).status(); } + + private: + HloVerifier verifier_; +}; + +TEST_F(MemorySpacePropagationTest, NoMemorySpace) { + absl::string_view hlo_string = R"( + HloModule NoMemorySpace + + %fused_computation { + %param_1.3 = s32[1]{0:T(128)} parameter(1) + %constant.2 = s32[]{:T(128)} constant(-2147483648) + %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5 + %param_2.3 = s32[5]{0:T(128)} parameter(2) + %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0 + %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3) + %param_0.1 = s32[6]{0:T(128)} parameter(0) + ROOT %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1) + } + + ENTRY %entry { + %param0 = s32[6]{0:T(128)} parameter(0) + %param1 = s32[1]{0:T(128)} parameter(1) + %param2 = s32[5]{0:T(128)} parameter(2) + %arg0 = s32[6]{0:T(128)} copy(%param0) + %arg1 = s32[1]{0:T(128)} copy(%param1) + %arg2 = s32[5]{0:T(128)} copy(%param2) + %fusion = s32[6]{0:T(128)} fusion(s32[6]{0:T(128)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)} %arg2), kind=kLoop, calls=%fused_computation + ROOT %root = s32[6]{0:T(128)} copy(%fusion) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + MemorySpacePropagation memory_space_propagation; + EXPECT_FALSE(memory_space_propagation.Run(module.get()).ValueOrDie()); + TF_ASSERT_OK_AND_ASSIGN(auto ref, ParseAndReturnVerifiedModule(hlo_string)); + EXPECT_EQ(module->Hash(), ref->Hash()); +} + +TEST_F(MemorySpacePropagationTest, NonTupleOutput) { + absl::string_view hlo_string = R"( + HloModule NonTupleOutput + + %fused_computation { + %param_1.3 = s32[1]{0:T(128)} parameter(1) + %constant.2 = s32[]{:T(128)} constant(-2147483648) + %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5 + %param_2.3 = s32[5]{0:T(128)} parameter(2) + %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0 + %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3) + %param_0.1 = s32[6]{0:T(128)} parameter(0) + ROOT %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1) + } + + ENTRY %entry { + %param0 = s32[6]{0:T(128)} parameter(0) + %param1 = s32[1]{0:T(128)} parameter(1) + %param2 = s32[5]{0:T(128)} parameter(2) + %arg0 = s32[6]{0:T(128)S(1)} copy(%param0) + %arg1 = s32[1]{0:T(128)} copy(%param1) + %arg2 = s32[5]{0:T(128)S(1)} copy(%param2) + %fusion = s32[6]{0:T(128)S(1)} fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation + ROOT %root = s32[6]{0:T(128)} copy(%fusion) + } + )"; + absl::string_view expected_hlo_string = R"( + HloModule NonTupleOutput + + %fused_computation { + %param_1.3 = s32[1]{0:T(128)} 
parameter(1) + %constant.2 = s32[]{:T(128)} constant(-2147483648) + %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5 + %param_2.3 = s32[5]{0:T(128)S(1)} parameter(2) + %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0 + %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3) + %param_0.1 = s32[6]{0:T(128)S(1)} parameter(0) + ROOT %add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1) + } + + ENTRY %entry { + %param0 = s32[6]{0:T(128)} parameter(0) + %param1 = s32[1]{0:T(128)} parameter(1) + %param2 = s32[5]{0:T(128)} parameter(2) + %arg0 = s32[6]{0:T(128)S(1)} copy(%param0) + %arg1 = s32[1]{0:T(128)} copy(%param1) + %arg2 = s32[5]{0:T(128)S(1)} copy(%param2) + %fusion = s32[6]{0:T(128)S(1)} fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation + ROOT %root = s32[6]{0:T(128)} copy(%fusion) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + MemorySpacePropagation memory_space_propagation; + EXPECT_TRUE(memory_space_propagation.Run(module.get()).ValueOrDie()); + TF_EXPECT_OK(Verify(module.get())); + TF_ASSERT_OK_AND_ASSIGN(auto ref, + ParseAndReturnVerifiedModule(expected_hlo_string)); + EXPECT_EQ(module->Hash(), ref->Hash()); +} + +TEST_F(MemorySpacePropagationTest, TupleOutput) { + absl::string_view hlo_string = R"( + HloModule TupleOutput + + %fused_computation { + %param_1.3 = s32[1]{0:T(128)} parameter(1) + %constant.2 = s32[]{:T(128)} constant(-2147483648) + %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5 + %param_2.3 = s32[5]{0:T(128)} parameter(2) + %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0 + %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3) + %param_0.1 = s32[6]{0:T(128)} parameter(0) + %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1) + %multiply.0 = s32[6]{0:T(128)} multiply(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1) + ROOT %tuple = (s32[6]{0:T(128)}, s32[6]{0:T(128)}) tuple(%add.0, %multiply.0) + } + + ENTRY %entry { + %param0 = s32[6]{0:T(128)} parameter(0) + %param1 = s32[1]{0:T(128)} parameter(1) + %param2 = s32[5]{0:T(128)} parameter(2) + %arg0 = s32[6]{0:T(128)S(1)} copy(%param0) + %arg1 = s32[1]{0:T(128)} copy(%param1) + %arg2 = s32[5]{0:T(128)S(1)} copy(%param2) + %fusion = (s32[6]{0:T(128)S(1)}, s32[6]{0:T(128)}) fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation + %gte0 = s32[6]{0:T(128)S(1)} get-tuple-element(%fusion), index=0 + %gte1 = s32[6]{0:T(128)} get-tuple-element(%fusion), index=1 + ROOT %root = s32[6]{0:T(128)} add(%gte0, %gte1) + } + )"; + absl::string_view expected_hlo_string = R"( + HloModule TupleOutput + + %fused_computation { + %param_1.3 = s32[1]{0:T(128)} parameter(1) + %constant.2 = s32[]{:T(128)} constant(-2147483648) + %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5 + %param_2.3 = s32[5]{0:T(128)S(1)} parameter(2) + %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0 + %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3) + %param_0.1 = s32[6]{0:T(128)S(1)} 
parameter(0) + %add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1) + %multiply.0 = s32[6]{0:T(128)} multiply(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1) + ROOT %tuple = (s32[6]{0:T(128)S(1)}, s32[6]{0:T(128)}) tuple(%add.0, %multiply.0) + } + + ENTRY %entry { + %param0 = s32[6]{0:T(128)} parameter(0) + %param1 = s32[1]{0:T(128)} parameter(1) + %param2 = s32[5]{0:T(128)} parameter(2) + %arg0 = s32[6]{0:T(128)S(1)} copy(%param0) + %arg1 = s32[1]{0:T(128)} copy(%param1) + %arg2 = s32[5]{0:T(128)S(1)} copy(%param2) + %fusion = (s32[6]{0:T(128)S(1)}, s32[6]{0:T(128)}) fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation + %gte0 = s32[6]{0:T(128)S(1)} get-tuple-element(%fusion), index=0 + %gte1 = s32[6]{0:T(128)} get-tuple-element(%fusion), index=1 + ROOT %root = s32[6]{0:T(128)} add(%gte0, %gte1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + MemorySpacePropagation memory_space_propagation; + EXPECT_TRUE(memory_space_propagation.Run(module.get()).ValueOrDie()); + TF_EXPECT_OK(Verify(module.get())); + TF_ASSERT_OK_AND_ASSIGN(auto ref, + ParseAndReturnVerifiedModule(expected_hlo_string)); + EXPECT_EQ(module->Hash(), ref->Hash()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/mlir_gpu/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/BUILD index cd679f7412e..07655a61074 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/BUILD @@ -185,11 +185,11 @@ cc_library( "@llvm-project//mlir:LinalgOps", "@llvm-project//mlir:LinalgToLLVM", "@llvm-project//mlir:LinalgTransforms", - "@llvm-project//mlir:LoopOps", - "@llvm-project//mlir:LoopOpsTransforms", - "@llvm-project//mlir:LoopsToGPUPass", "@llvm-project//mlir:NVVMDialect", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFToGPUPass", + "@llvm-project//mlir:SCFTransforms", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", diff --git a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc index 56684b1f726..d5cad385324 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include "llvm/Support/raw_ostream.h" -#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" // from @llvm-project +#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc index 33d3690d4ab..4645b084eb6 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc @@ -19,8 +19,8 @@ limitations under the License. 
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // from @llvm-project -#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" // from @llvm-project -#include "mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h" // from @llvm-project +#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h" // from @llvm-project +#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project #include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project @@ -31,9 +31,9 @@ limitations under the License. #include "mlir/Dialect/LLVMIR/NVVMDialect.h" // from @llvm-project #include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project #include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project -#include "mlir/Dialect/LoopOps/LoopOps.h" // from @llvm-project -#include "mlir/Dialect/LoopOps/Passes.h" // from @llvm-project -#include "mlir/Dialect/LoopOps/Transforms.h" // from @llvm-project +#include "mlir/Dialect/SCF/Passes.h" // from @llvm-project +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project +#include "mlir/Dialect/SCF/Transforms.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project @@ -45,6 +45,7 @@ limitations under the License. #include "mlir/IR/Region.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Transforms/BufferPlacement.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "mlir/Transforms/LoopUtils.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project @@ -60,34 +61,6 @@ namespace { using ::mlir::xla_lhlo::FusionOp; -// Following are some small transformations that are required to clean up code -// after lowering from linalg to loops. - -// A simple pass that applies lowering of HLO to LHLO only within LHLO ops that -// contain regions with HLO ops, e.g. FusionOp, ReduceOp, SelectAndScatterOp. -// This is needed, as these ops are not closed from above and hence nested pass -// managers can not be applied. -struct NestedHloRegionsConverter - : public mlir::PassWrapper { - void runOnFunction() override { - auto& ctx = getContext(); - mlir::OwningRewritePatternList patterns; - mlir::ConversionTarget target(ctx); - target.addLegalDialect<::mlir::xla_lhlo::XlaLhloDialect>(); - ::mlir::xla_hlo::populateHLOToLHLOConversionPattern(&ctx, &patterns); - - getFunction().walk([&](mlir::Operation* op) { - if (op->getNumRegions() == 0) { - return; - } - if (failed(applyPartialConversion(op, target, patterns, nullptr))) { - signalPassFailure(); - } - }); - } -}; - // Replaces a FusionOp by the operations contained in its region. struct FusionOpRemover : public mlir::PassWrapper { @@ -132,7 +105,7 @@ struct StoreForwardingPass // No store operation found. Continue search outside of the parallel // loop if block is in a parallel loop. 
if (auto parallelOp = - llvm::dyn_cast(block->getParentOp())) { + llvm::dyn_cast(block->getParentOp())) { return findStore(parallelOp.getOperation(), matches); } return {}; @@ -378,7 +351,7 @@ struct FixKernelFunctionSignatures struct MapParallelLoops : public mlir::PassWrapper { void runOnFunction() override { - mlir::greedilyMapParallelLoopsToGPU(getFunction().getBody()); + mlir::greedilyMapParallelSCFToGPU(getFunction().getBody()); } }; @@ -388,8 +361,8 @@ struct MapParallelLoops struct FuseInnerParallelLoops : public mlir::PassWrapper { void runOnFunction() override { - getFunction().walk([](mlir::loop::ParallelOp op) { - mlir::loop::naivelyFuseParallelOps(op.region()); + getFunction().walk([](mlir::scf::ParallelOp op) { + mlir::scf::naivelyFuseParallelOps(op.region()); }); } }; @@ -401,7 +374,7 @@ struct ParallelLoopCollapsingToFirstDim void runOnOperation() override { mlir::Operation* module = getOperation(); - module->walk([&](mlir::loop::ParallelOp op) { + module->walk([&](mlir::scf::ParallelOp op) { unsigned num_loops = op.getNumLoops(); std::vector combinedLoops; combinedLoops.reserve(num_loops); @@ -436,8 +409,10 @@ Status LowerLHLOToGPU(mlir::ModuleOp module, tiling_for_unrolling.append(tile_sizes.begin(), tile_sizes.end()); } - // First, lower bodies of LHLO operations that contain HLO ops. - pm.addPass(absl::make_unique()); + // Legalize from HLO to LHLO. + pm.addPass(::mlir::xla_hlo::createLegalizeToLhloPass()); + // Moving `AllocOp`s and inserting missing `DeallocOp`s + pm.addPass(::mlir::createBufferPlacementPass()); // Next, we can strip the outer fusion operation. pm.addPass(absl::make_unique()); // Remove unnecessary LHLO copies. diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index ab71c30dcae..2ed5e709d81 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -313,6 +313,8 @@ StatusOr> Service::CreateModuleConfig( if (execution_options->num_partitions() > 0) { config->set_num_partitions(execution_options->num_partitions()); } + config->set_use_spmd_partitioning( + execution_options->use_spmd_partitioning()); config->set_seed(execution_options->seed()); config->set_launch_id(execution_options->launch_id()); config->set_debug_options(execution_options->debug_options()); diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index f3c8eec1751..75a80747c1d 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -1856,7 +1857,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, switch (fft_type) { case FFT: case IFFT: - if (in.element_type() != C64) { + if (!primitive_util::IsComplexType(in.element_type())) { return InvalidArgument("%s requires complex input type, found %s.", FftType_Name(fft_type), PrimitiveType_Name(in.element_type())); @@ -1864,8 +1865,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, RET_CHECK_RANK(in); return in; case RFFT: { - if (in.element_type() != F32) { - return InvalidArgument("RFFT requires F32 input type, found %s.", + if (in.element_type() != F32 && in.element_type() != F64) { + return InvalidArgument("RFFT requires F32 or F64 input type, found %s.", PrimitiveType_Name(in.element_type())); } RET_CHECK_RANK(in); @@ -1880,7 +1881,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, fft_length[i]); } } - Shape result = ShapeUtil::ChangeElementType(in, C64); + Shape result = ShapeUtil::ChangeElementType( + in, in.element_type() == F32 ? C64 : C128); // Preserve the size of zero-sized dimensions. if (fft_length[fft_rank - 1] != 0) { result.set_dimensions(result.dimensions_size() - 1, @@ -1889,8 +1891,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return result; } case IRFFT: { - if (in.element_type() != C64) { - return InvalidArgument("IRFFT requires C64 input type, found %s.", + if (!primitive_util::IsComplexType(in.element_type())) { + return InvalidArgument("IRFFT requires complex input type, found %s.", PrimitiveType_Name(in.element_type())); } RET_CHECK_RANK(in); @@ -1999,6 +2001,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return a; } +/* static */ StatusOr ShapeInference::InferAllGatherShape( + const Shape& operand_shape, int64 all_gather_dimension, int64 shard_count) { + TF_RET_CHECK(all_gather_dimension >= 0); + TF_RET_CHECK(all_gather_dimension < operand_shape.rank()); + TF_RET_CHECK(shard_count > 0); + auto shape = operand_shape; + shape.set_dimensions(all_gather_dimension, + shard_count * shape.dimensions(all_gather_dimension)); + return shape; +} + /* static */ StatusOr ShapeInference::InferAllReduceShape( absl::Span operand_shapes) { for (const Shape* operand_shape : operand_shapes) { diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 2e96a77aa22..2cb5930d098 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -123,6 +123,12 @@ class ShapeInference { // Infers the shape produced by the given triangular solve operation. static StatusOr InferCholeskyShape(const Shape& a); + // Infers the shape produced by an all-gather with the given operand shape, + // concat dimension, and shard count. + static StatusOr InferAllGatherShape(const Shape& operand_shape, + int64 all_gather_dimension, + int64 shard_count); + // Infers the shape produced by a cross replica sum with the given operand // shapes. 
static StatusOr InferAllReduceShape( diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 448f5119546..b5ecf6e583e 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -615,8 +615,7 @@ namespace fft { static const char* unsupported_rank = "only supports ranks 1-3"; static const char* invalid_rank = "requires input of at least same rank"; static const char* requires_complex_input = "requires complex input type"; -static const char* requires_f32_input = "requires F32 input type"; -static const char* requires_c64_input = "requires C64 input type"; +static const char* requires_f32_input = "requires F32 or F64 input type"; static const char* dimensions_match = "innermost dimensions match fft_length"; static const char* innermost_dimension_matches = "innermost dimension matches fft_length/2+1"; @@ -654,7 +653,7 @@ TEST_F(ShapeInferenceTest, InferFftShapeTestFftTypes) { Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8}); Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8}); fft::Fail(shape_f32, type, {16, 8}, fft::requires_complex_input); - fft::Fail(shape_c128, type, {16, 8}, fft::requires_complex_input); + fft::Pass(shape_c128, type, {16, 8}, shape_c128); } TEST_F(ShapeInferenceTest, InferFftShapeTestIfftRanks) { @@ -672,7 +671,7 @@ TEST_F(ShapeInferenceTest, InferFftShapeTestIfftTypes) { Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8}); Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8}); fft::Fail(shape_f32, type, {16, 8}, fft::requires_complex_input); - fft::Fail(shape_c128, type, {16, 8}, fft::requires_complex_input); + fft::Pass(shape_c128, type, {16, 8}, shape_c128); } TEST_F(ShapeInferenceTest, InferFftShapeTestRfftRanks) { @@ -747,9 +746,10 @@ TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftDimensions) { TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftTypes) { FftType type = FftType::IRFFT; Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8}); - Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8}); - fft::Fail(shape_f32, type, {16, 8}, fft::requires_c64_input); - fft::Fail(shape_c128, type, {16, 8}, fft::requires_c64_input); + Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 5}); + Shape shape_f64_out = ShapeUtil::MakeShape(F64, {16, 8}); + fft::Fail(shape_f32, type, {16, 8}, fft::requires_complex_input); + fft::Pass(shape_c128, type, {16, 8}, shape_f64_out); } TEST_F(ShapeInferenceTest, MapThatChangesElementType) { diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index a1872330648..995b0ece7cd 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.h +++ b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -22,6 +22,7 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -93,6 +94,18 @@ class ShapedBuffer { buffers_.replace_shape_ptr(&on_device_shape_); } + // Reset the shape of this shaped buffer and underlying buffer structure. + // + // Precondition: EqualStructure(this->on_device_shape_, on_device_shape). 
+ void set_shapes(const Shape& on_host_shape, const Shape& on_device_shape) { + CHECK(ShapeUtil::EqualStructure(on_device_shape, on_device_shape_)) + << "Structures are not the same. new: " << on_device_shape + << ", old: " << on_device_shape_; + on_host_shape_ = on_host_shape; + on_device_shape_ = on_device_shape; + buffers_.replace_shape_ptr(&on_device_shape_); + } + // Returns the underlying ShapeTree containing all the device addresses in the // ShapedBuffer. const ShapeTree& buffers() const { return buffers_; } @@ -124,9 +137,8 @@ class ShapedBuffer { std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer); -// ShapedBuffer derived class which allocates all internal buffers on -// construction and deallocates the memory when the object is -// destructed. +// ScopedShapedBuffer takes allocated buffers as inputs, and deallocates on +// destruction. This class represents an owning wrapper around `ShapedBuffer`. // // TODO(timshen): Remove inheritance between ScopedShapedBuffer and // ShapedBuffer. There should never be a need to consider a ScopedShapedBuffer diff --git a/tensorflow/compiler/xla/service/sharding_propagation.cc b/tensorflow/compiler/xla/service/sharding_propagation.cc new file mode 100644 index 00000000000..c6990e76c95 --- /dev/null +++ b/tensorflow/compiler/xla/service/sharding_propagation.cc @@ -0,0 +1,1491 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/sharding_propagation.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_split.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +using ComputationMap = + absl::flat_hash_map; + +// Returns true iff the specified hlo or sharding has a spatially partitioned +// sharding (tiled or replicated) what can be propagated by sharding +// propagation. 
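+// As a rough illustration (HLO sharding syntax; the two-device assignment is
+// only an example, not something this pass requires):
+//   {devices=[2,1]0,1}   tiled across two devices  -> spatially partitioned
+//   {replicated}         replicated                -> spatially partitioned
+//   {maximal device=0}   pinned to a single device -> not spatially partitioned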
+bool IsSpatiallyPartitioned(const HloSharding& sharding) {
+  if (sharding.IsTuple()) {
+    return absl::c_any_of(sharding.tuple_elements(), IsSpatiallyPartitioned);
+  } else {
+    return !sharding.IsTileMaximal() || sharding.IsReplicated();
+  }
+}
+bool IsSpatiallyPartitioned(const HloInstruction* hlo) {
+  return hlo->has_sharding() && IsSpatiallyPartitioned(hlo->sharding());
+}
+
+// Returns true if the lhs sharding is preferable to the rhs sharding.
+// The most specific sharding is tiled, followed by single-device tile maximal,
+// and finally replicated. This order aims primarily to reduce memory usage and
+// secondarily to reduce total compute.
+// Note: This does NOT provide a total ordering, as two different shardings can
+// have the same preference level.
+bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs) {
+  CHECK_EQ(lhs.IsTuple(), rhs.IsTuple());
+  if (lhs.IsTuple()) {
+    // For tuples we consider lhs to have a better sharding if none of the
+    // elements are worse and at least one element is better than in the rhs
+    // sharding.
+    const auto& lhs_shardings = lhs.tuple_elements();
+    const auto& rhs_shardings = rhs.tuple_elements();
+    CHECK_EQ(lhs_shardings.size(), rhs_shardings.size());
+    bool is_better = false;
+    for (int64 i = 0; i < lhs_shardings.size(); ++i) {
+      if (IsShardingMoreSpecific(rhs_shardings[i], lhs_shardings[i])) {
+        return false;
+      }
+      if (IsShardingMoreSpecific(lhs_shardings[i], rhs_shardings[i])) {
+        is_better = true;
+      }
+    }
+    return is_better;
+  }
+  if (!rhs.IsTileMaximal()) {
+    // If we already have a non-tile-maximal sharding then we can't improve
+    // that.
+    return false;
+  } else if (!rhs.IsReplicated()) {
+    // If we are not replicated then only tiled (not tile maximal) shardings
+    // can improve us.
+    return !lhs.IsTileMaximal();
+  } else {
+    // If we are replicated then any non-replicated sharding can improve us.
+    return !lhs.IsReplicated();
+  }
+}
+
+// Returns a sharding where each tuple element is chosen as the more specific
+// one of the corresponding elements in a and b. Requires a and b to have the
+// same tuple nesting.
+HloSharding MergeForMoreSpecificSharding(const HloSharding& a,
+                                         const HloSharding& b) {
+  if (a.IsTuple()) {
+    HloSharding result = a;
+    CHECK(b.IsTuple());
+    CHECK_EQ(a.tuple_elements().size(), b.tuple_elements().size());
+    for (int64 i = 0; i < result.tuple_elements().size(); ++i) {
+      result.tuple_elements()[i] = MergeForMoreSpecificSharding(
+          a.tuple_elements()[i], b.tuple_elements()[i]);
+    }
+    return result;
+  }
+  return IsShardingMoreSpecific(a, b) ? a : b;
+}
+
+// Updates the sharding of the specified instruction with the specified
+// sharding if it is better than the current one, and returns true if a new
+// sharding has been applied.
+bool MaybeImproveInstructionSharding(const HloSharding& sharding,
+                                     HloInstruction* instruction) {
+  // We don't want to propagate tile maximal shardings.
+  if (!IsSpatiallyPartitioned(sharding)) {
+    return false;
+  }
+  // Any sharding is better than no sharding.
+  if (!instruction->has_sharding()) {
+    instruction->set_sharding(sharding);
+    return true;
+  }
+  if (IsShardingMoreSpecific(sharding, instruction->sharding())) {
+    instruction->set_sharding(sharding);
+    return true;
+  }
+  return false;
+}
+
+// Sets the sharding for every element within a tuple to replicated (default
+// sharding). This is necessary because there is no way to represent a tuple
+// sharding when only some of the elements are sharded.
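+// For example (an arbitrary 3-tuple), the result initially gets
+//   sharding={{replicated}, {replicated}, {replicated}}
+// and individual leaf shardings are then overwritten as more specific
+// shardings are inferred.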
+void SetDefaultTupleSharding(HloInstruction* instruction) { + instruction->set_sharding( + HloSharding::SingleTuple(instruction->shape(), HloSharding::Replicate())); +} + +// We consider a convolution kernel to be small iff it is smaller along all +// spatial dimensions then the output of the convolution. The rational is that +// we can either shard the kernel or the output and we want to shard the larger +// one for better efficiency. +bool IsConvolutionKernelSmall(const HloInstruction* instruction) { + CHECK_EQ(instruction->opcode(), HloOpcode::kConvolution); + const HloInstruction* rhs = instruction->operand(1); + const auto& dnums = instruction->convolution_dimension_numbers(); + for (int64 i = 0; i < dnums.input_spatial_dimensions().size(); ++i) { + int64 kernel_dim = + rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)); + int64 output_dim = + instruction->shape().dimensions(dnums.output_spatial_dimensions(i)); + if (kernel_dim >= output_dim) { + return false; + } + } + return true; +} + +// Return the operand which is the most suitable for determining the sharding +// for the specified instruction or nullptr if there isn't any suitable operand. +const HloInstruction* PickRepresentativeOperand( + const HloInstruction* instruction) { + switch (instruction->opcode()) { + case HloOpcode::kMap: + case HloOpcode::kPad: + case HloOpcode::kPower: + case HloOpcode::kReverse: + case HloOpcode::kSlice: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: + // For these opcodes the output sharding has to be determined by the + // sharding of the first operand but we can only determine sharding based + // on it if it already has a sharding. + if (instruction->operand(0)->has_sharding()) { + return instruction->operand(0); + } + return nullptr; + case HloOpcode::kAbs: + case HloOpcode::kAdd: + case HloOpcode::kAnd: + case HloOpcode::kAtan2: + case HloOpcode::kBitcastConvert: + case HloOpcode::kCeil: + case HloOpcode::kClamp: + case HloOpcode::kClz: + case HloOpcode::kCompare: + case HloOpcode::kComplex: + case HloOpcode::kConcatenate: + case HloOpcode::kConvert: + case HloOpcode::kCopy: + case HloOpcode::kCos: + case HloOpcode::kAllGather: + case HloOpcode::kAllReduce: + case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: + case HloOpcode::kDivide: + case HloOpcode::kExp: + case HloOpcode::kExpm1: + case HloOpcode::kFloor: + case HloOpcode::kImag: + case HloOpcode::kIsFinite: + case HloOpcode::kLog: + case HloOpcode::kLog1p: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kMultiply: + case HloOpcode::kNegate: + case HloOpcode::kNot: + case HloOpcode::kOr: + case HloOpcode::kPopulationCount: + case HloOpcode::kReal: + case HloOpcode::kReducePrecision: + case HloOpcode::kRemainder: + case HloOpcode::kRoundNearestAfz: + case HloOpcode::kRsqrt: + case HloOpcode::kSelect: + case HloOpcode::kSign: + case HloOpcode::kSin: + case HloOpcode::kSort: + case HloOpcode::kSqrt: + case HloOpcode::kCbrt: + case HloOpcode::kSubtract: + case HloOpcode::kTanh: + case HloOpcode::kTupleSelect: + case HloOpcode::kWhile: + case HloOpcode::kXor: { + // For these opcodes the output sharding can be determined by any operand + // so we find the operand with the most specific sharding. 
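+      // For instance, for add(%x, %y) where %x carries {devices=[2,1]0,1} and
+      // %y carries {replicated}, %x is picked as the representative operand,
+      // since a tiled sharding is more specific than a replicated one (the
+      // two-device assignment is only illustrative).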
+ const HloInstruction* best_operand = nullptr; + for (const HloInstruction* operand : instruction->operands()) { + if (operand->has_sharding() && + (best_operand == nullptr || + IsShardingMoreSpecific(operand->sharding(), + best_operand->sharding()))) { + best_operand = operand; + } + } + return best_operand; + } + + // There is no suitable operand for the rest of the opcodes. + case HloOpcode::kAddDependency: + case HloOpcode::kAfterAll: + case HloOpcode::kBatchNormGrad: + case HloOpcode::kBatchNormInference: + case HloOpcode::kBatchNormTraining: + case HloOpcode::kBitcast: + case HloOpcode::kBroadcast: + case HloOpcode::kCall: + case HloOpcode::kCholesky: + case HloOpcode::kCollectivePermuteDone: + case HloOpcode::kCollectivePermuteStart: + case HloOpcode::kConditional: + case HloOpcode::kConstant: + case HloOpcode::kConvolution: + case HloOpcode::kCopyDone: + case HloOpcode::kCopyStart: + case HloOpcode::kCustomCall: + case HloOpcode::kDomain: + case HloOpcode::kDot: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kFft: + case HloOpcode::kFusion: + case HloOpcode::kGather: + case HloOpcode::kGetTupleElement: + case HloOpcode::kInfeed: + case HloOpcode::kIota: + case HloOpcode::kOutfeed: + case HloOpcode::kParameter: + case HloOpcode::kPartitionId: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: + case HloOpcode::kReplicaId: + case HloOpcode::kReshape: + case HloOpcode::kRng: + case HloOpcode::kRngGetAndUpdateState: + case HloOpcode::kRngBitGenerator: + case HloOpcode::kScatter: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kTrace: + case HloOpcode::kTranspose: + case HloOpcode::kTriangularSolve: + case HloOpcode::kTuple: + case HloOpcode::kGetDimensionSize: + case HloOpcode::kSetDimensionSize: + return nullptr; + } +} + +bool SupportSpatialPartitioning(const HloInstruction* instruction, + const ComputationMap& computation_map, + bool is_spmd) { + if (instruction->parent()->root_instruction() == instruction && + computation_map.find(instruction->parent()) == computation_map.end()) { + // We don't support sharding the root instruction of a computation yet, + // unless the computation is a while body. + return false; + } + + if (instruction->IsElementwise() && + (instruction->opcode() != HloOpcode::kRng || is_spmd)) { + return true; + } + switch (instruction->opcode()) { + case HloOpcode::kBroadcast: + case HloOpcode::kConcatenate: + case HloOpcode::kConditional: + case HloOpcode::kConstant: + case HloOpcode::kConvolution: + case HloOpcode::kDot: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kGather: + case HloOpcode::kGetTupleElement: + case HloOpcode::kInfeed: + case HloOpcode::kIota: + case HloOpcode::kPad: + case HloOpcode::kReduceWindow: + case HloOpcode::kReshape: + case HloOpcode::kScatter: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kSlice: + case HloOpcode::kSort: + case HloOpcode::kTranspose: + case HloOpcode::kTuple: + case HloOpcode::kWhile: + case HloOpcode::kReduce: + return true; + case HloOpcode::kAllReduce: + // Only if channel_id is not specified. 
+ return instruction->channel_id() == absl::nullopt; + case HloOpcode::kParameter: + return computation_map.find(instruction->parent()) != + computation_map.end(); + case HloOpcode::kReverse: + return is_spmd; + default: + return false; + } +} + +// Tries to update the sharding of the specified instruction based on its +// operands and returns true if the sharding of the instruction have been +// changed and false otherwise. +bool InferShardingFromOperands(HloInstruction* instruction, + const ComputationMap& computation_map, + bool is_spmd, bool aggressive_prop) { + if (!SupportSpatialPartitioning(instruction, computation_map, is_spmd)) { + // If an array shaped HLO doesn't support spatial partitioning but at least + // one of its operand is replicated then we make the HLO replicated as well. + if (instruction->shape().IsTuple() || instruction->operand_count() == 0 || + instruction == instruction->parent()->root_instruction() || + instruction->HasSideEffect()) { + return false; + } + if (absl::c_any_of(instruction->operands(), [](const HloInstruction* op) { + return op->has_sharding() && op->sharding().IsReplicated(); + })) { + return MaybeImproveInstructionSharding(HloSharding::Replicate(), + instruction); + } + return false; + } + + switch (instruction->opcode()) { + case HloOpcode::kGetTupleElement: { + const HloInstruction* operand = instruction->operand(0); + if (!IsSpatiallyPartitioned(operand)) { + return false; + } + HloSharding new_sharding = operand->sharding().GetSubSharding( + operand->shape(), {instruction->tuple_index()}); + return MaybeImproveInstructionSharding(new_sharding, instruction); + } + case HloOpcode::kTuple: { + if (absl::c_none_of(instruction->operands(), + [](const HloInstruction* hlo) { + return IsSpatiallyPartitioned(hlo); + })) { + // None of the operands have a spatially partitioned sharding. + return false; + } + bool changed = false; + if (!instruction->has_sharding()) { + // Set the sharding for all elements in the tuple because it isn't + // possible to set a partial sharding. + SetDefaultTupleSharding(instruction); + changed = true; + } + // Go through each operand and if the operand has a sharding that is + // better than the current sharding for that tuple element then update + // it. + const Shape& shape = instruction->shape(); + std::vector sub_shardings = + instruction->sharding().tuple_elements(); + int64 sub_sharding_index = 0; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const HloInstruction* operand = instruction->operand(i); + if (operand->has_sharding()) { + if (operand->shape().IsTuple()) { + for (int64 i = 0, e = ShapeUtil::GetLeafCount(operand->shape()); + i < e; ++i) { + if (IsShardingMoreSpecific( + operand->sharding().tuple_elements()[i], + sub_shardings[sub_sharding_index + i])) { + sub_shardings[sub_sharding_index + i] = + operand->sharding().tuple_elements()[i]; + } + } + } else { + if (IsShardingMoreSpecific(operand->sharding(), + sub_shardings[sub_sharding_index])) { + sub_shardings[sub_sharding_index] = operand->sharding(); + } + } + } + sub_sharding_index += ShapeUtil::GetLeafCount(operand->shape()); + } + + HloSharding new_sharding = HloSharding::Tuple(shape, sub_shardings); + if (new_sharding != instruction->sharding()) { + instruction->set_sharding(new_sharding); + return true; + } + return changed; + } + case HloOpcode::kReduce: { + // Reduce could have a tuple shape, where the first half of operands are + // the arrays to reduce, and the second half of operands are the init + // values. 
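+      // For example, a variadic reduce(%a, %b, %init_a, %init_b) has
+      // operand_count() == 4: operands 0-1 are the arrays being reduced and
+      // operands 2-3 are the corresponding init values (the names are
+      // illustrative only).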
+ bool changed = false; + for (int64 operand_id = 0; operand_id < instruction->operand_count() / 2; + ++operand_id) { + const HloInstruction* operand = instruction->operand(operand_id); + if (!IsSpatiallyPartitioned(operand)) { + continue; + } + auto get_maybe_tuple_sharding = [&](const HloSharding& sharding) { + if (instruction->operand_count() == 2) { + return sharding; + } + std::vector tuple(instruction->operand_count() / 2, + sharding); + return HloSharding::Tuple(instruction->shape(), tuple); + }; + if (operand->sharding().IsReplicated()) { + changed |= MaybeImproveInstructionSharding( + get_maybe_tuple_sharding(HloSharding::Replicate()), instruction); + continue; + } + if (absl::c_any_of(instruction->dimensions(), [operand](int64 dim) { + return operand->sharding().tile_assignment().dim(dim) > 1; + })) { + // We are reducing along one of the sharded dimensions. We don't + // support tiled sharding in this case. + changed |= MaybeImproveInstructionSharding( + get_maybe_tuple_sharding(HloSharding::Replicate()), instruction); + } else { + // We are reducing along some of the non-sharded dimensions. The + // result sharding should be the same as the operand sharding with the + // reduction dimensions removed as they are removed from the result + // shape. + std::vector target_tile_assignment_dimensions; + const auto& dimensions = instruction->dimensions(); + for (int64 i = 0; i < operand->shape().rank(); ++i) { + if (absl::c_find(dimensions, i) == dimensions.end()) { + target_tile_assignment_dimensions.push_back( + operand->sharding().tile_assignment().dim(i)); + } + } + Array new_tile_assignment = + operand->sharding().tile_assignment(); + new_tile_assignment.Reshape(target_tile_assignment_dimensions); + // Use the same sharding for all tuple elements, because they are part + // of the same reduce instruction. + HloSharding new_sharding = + get_maybe_tuple_sharding(HloSharding::Tile(new_tile_assignment)); + changed |= MaybeImproveInstructionSharding(new_sharding, instruction); + } + } + return changed; + } + case HloOpcode::kBroadcast: { + const HloInstruction* op = instruction->operand(0); + if (!IsSpatiallyPartitioned(op) || op->sharding().IsReplicated()) { + return false; + } + // Heuristic: If an operand is more than 8 times fewer elements than its + // output, do not propagate sharding. + if (ShapeUtil::ElementsIn(instruction->shape()) > + 8 * ShapeUtil::ElementsIn(op->shape())) { + return false; + } + // The output will be tiled along the broadcasted dimension the same way + // as the input for the broadcast while the other dimensions are kept + // non-tiled. 
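+      // As an illustrative example, broadcasting an f32[8] operand sharded as
+      // {devices=[2]0,1} into f32[8,16] with dimensions={0} yields an output
+      // sharding of {devices=[2,1]0,1}: dimension 0 keeps the operand's tiling
+      // while the new dimension 1 stays untiled.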
+ std::vector target_tile_assignment_dimensions; + const auto& dimensions = instruction->dimensions(); + for (int64 i = 0; i < instruction->shape().rank(); ++i) { + auto it = absl::c_find(dimensions, i); + if (it == dimensions.end()) { + target_tile_assignment_dimensions.push_back(1); + } else { + const int64 source_dim = std::distance(dimensions.begin(), it); + target_tile_assignment_dimensions.push_back( + op->sharding().tile_assignment().dim(source_dim)); + } + } + Array new_tile_assignment = op->sharding().tile_assignment(); + new_tile_assignment.Reshape(target_tile_assignment_dimensions); + HloSharding new_sharding = HloSharding::Tile(new_tile_assignment); + return MaybeImproveInstructionSharding(new_sharding, instruction); + } + case HloOpcode::kConvolution: { + const auto& dnums = instruction->convolution_dimension_numbers(); + const HloInstruction* lhs = instruction->operand(0); + const HloInstruction* rhs = instruction->operand(1); + auto get_tiled_sharding_based_on_lhs = [&] { + CHECK(!lhs->sharding().IsTileMaximal()); + std::vector output_to_lhs_indices(instruction->shape().rank()); + output_to_lhs_indices[dnums.output_batch_dimension()] = + dnums.input_batch_dimension(); + output_to_lhs_indices[dnums.output_feature_dimension()] = + dnums.input_feature_dimension(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + output_to_lhs_indices[dnums.output_spatial_dimensions(i)] = + dnums.input_spatial_dimensions(i); + } + return hlo_sharding_util::TransposeSharding(lhs->sharding(), + output_to_lhs_indices); + }; + auto get_tiled_sharding_based_on_rhs = [&] { + CHECK(!rhs->sharding().IsTileMaximal()); + std::vector output_to_rhs_indices(instruction->shape().rank()); + output_to_rhs_indices[dnums.output_batch_dimension()] = + dnums.kernel_input_feature_dimension(); + output_to_rhs_indices[dnums.output_feature_dimension()] = + dnums.kernel_output_feature_dimension(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + output_to_rhs_indices[dnums.output_spatial_dimensions(i)] = + dnums.kernel_spatial_dimensions(i); + } + return hlo_sharding_util::TransposeSharding(rhs->sharding(), + output_to_rhs_indices); + }; + if (auto dot_dims = + dot_as_convolution_util::ParseDotGeneralFromConvolution( + instruction)) { + // lhs_or_rhs: lhs is 0 and rhs is 1. + auto partitioned_only_along = + [&](const HloSharding& sharding, + std::vector& dims, + int64 lhs_or_rhs) { + if (sharding.IsTileMaximal()) { + return false; + } + int64 partition_count = 1; + for (const auto& dim : dims) { + if (lhs_or_rhs == 0) { + partition_count *= sharding.tile_assignment().dim(dim.lhs); + } else { + CHECK_EQ(lhs_or_rhs, 1); + partition_count *= sharding.tile_assignment().dim(dim.rhs); + } + } + return partition_count == + sharding.tile_assignment().num_elements(); + }; + // If LHS/RHS is partitioned only along the batch dimensions, propagate + // the sharding to the output, since batch dimensions are the easiest to + // partition. + if (IsSpatiallyPartitioned(lhs) && + partitioned_only_along(lhs->sharding(), dot_dims->batch_dims, 0)) { + return MaybeImproveInstructionSharding( + get_tiled_sharding_based_on_lhs(), instruction); + } + if (IsSpatiallyPartitioned(rhs) && + partitioned_only_along(rhs->sharding(), dot_dims->batch_dims, 1)) { + return MaybeImproveInstructionSharding( + get_tiled_sharding_based_on_rhs(), instruction); + } + if (aggressive_prop) { + // If LHS/RHS is partitioned only along the non-contracting + // dimensions, propagate the sharding to the output. 
+ const bool can_propagate_from_lhs = + IsSpatiallyPartitioned(lhs) && + partitioned_only_along(lhs->sharding(), + dot_dims->lhs_non_contracting_dims, 0); + const bool can_propagate_from_rhs = + IsSpatiallyPartitioned(rhs) && + partitioned_only_along(rhs->sharding(), + dot_dims->rhs_non_contracting_dims, 1); + // If we can propagate from both operands, choose the larger one which + // should help us reduce communications. + if (can_propagate_from_lhs && can_propagate_from_rhs) { + if (Product(lhs->shape().dimensions()) >= + Product(rhs->shape().dimensions())) { + return MaybeImproveInstructionSharding( + get_tiled_sharding_based_on_lhs(), instruction); + } else { + return MaybeImproveInstructionSharding( + get_tiled_sharding_based_on_rhs(), instruction); + } + } + if (can_propagate_from_lhs) { + return MaybeImproveInstructionSharding( + get_tiled_sharding_based_on_lhs(), instruction); + } + if (can_propagate_from_rhs) { + return MaybeImproveInstructionSharding( + get_tiled_sharding_based_on_rhs(), instruction); + } + } + } + + if (!IsSpatiallyPartitioned(lhs)) { + return false; + } + if (lhs->sharding().IsReplicated()) { + return MaybeImproveInstructionSharding(HloSharding::Replicate(), + instruction); + } + + if (IsConvolutionKernelSmall(instruction)) { + // If the kernel is small compared to the input then we can generate an + // output what is sharded the same way as the input. + const auto& tile_assignment = lhs->sharding().tile_assignment(); + if (tile_assignment.dim(dnums.input_feature_dimension()) > 1) { + return false; + } + return MaybeImproveInstructionSharding( + get_tiled_sharding_based_on_lhs(), instruction); + } + // If the kernel is large (e.g backward convolution) then we only support + // replicated output. + return MaybeImproveInstructionSharding(HloSharding::Replicate(), + instruction); + } + case HloOpcode::kTranspose: { + const HloInstruction* input = instruction->operand(0); + if (!IsSpatiallyPartitioned(input)) { + return false; + } + HloSharding sharding = hlo_sharding_util::TransposeSharding( + input->sharding(), instruction->dimensions()); + return MaybeImproveInstructionSharding(sharding, instruction); + } + case HloOpcode::kReduceWindow: { + const HloInstruction* lhs = instruction->operand(0); + if (!IsSpatiallyPartitioned(lhs)) { + return false; + } + + auto has_dilation = [](const WindowDimension& dimensions) { + return dimensions.base_dilation() > 1 || + dimensions.window_dilation() > 1; + }; + if (absl::c_any_of(instruction->window().dimensions(), has_dilation)) { + VLOG(2) << "Not applying sharding to reduce window because dilatation " + "isn't supported yet: " + << instruction->ToString(); + return false; + } + return MaybeImproveInstructionSharding(lhs->sharding(), instruction); + } + case HloOpcode::kSelectAndScatter: { + // Shard according to first operand, as output keeps the same shape. 
+ const HloInstruction* lhs = instruction->operand(0); + if (!IsSpatiallyPartitioned(lhs)) { + return false; + } + + auto has_base_dilation = [](const WindowDimension& dimensions) { + return dimensions.base_dilation() > 1; + }; + if (absl::c_any_of(instruction->window().dimensions(), + has_base_dilation)) { + VLOG(2) << "Not applying sharding to select-and-scatter because " + "base dilation isn't supported yet: " + << instruction->ToString(); + return false; + } + return MaybeImproveInstructionSharding(lhs->sharding(), instruction); + } + case HloOpcode::kReshape: { + if (!IsSpatiallyPartitioned(instruction->operand(0))) { + return false; + } + absl::optional new_sharding = + hlo_sharding_util::ReshapeSharding( + instruction->operand(0)->shape(), instruction->shape(), + instruction->operand(0)->sharding()); + if (new_sharding.has_value()) { + return MaybeImproveInstructionSharding(new_sharding.value(), + instruction); + } + return false; + } + case HloOpcode::kReverse: { + if (!IsSpatiallyPartitioned(instruction->operand(0))) { + return false; + } + return MaybeImproveInstructionSharding( + hlo_sharding_util::ReverseSharding( + instruction->operand(0)->sharding(), instruction->dimensions()), + instruction); + } + case HloOpcode::kDot: { + auto& dot_dim_numbs = instruction->dot_dimension_numbers(); + // Batch dimensions are the same for lhs and rhs on dot operations. + int64 num_batch_dims = dot_dim_numbs.lhs_batch_dimensions_size(); + std::vector contracting_dims(2); + contracting_dims[0] = dot_dim_numbs.lhs_contracting_dimensions(0); + contracting_dims[1] = dot_dim_numbs.rhs_contracting_dimensions(0); + std::vector ops_sharding(2, nullptr); + for (int64 op_num = 0; op_num < 2; ++op_num) { + const HloInstruction* op = instruction->operand(op_num); + if (IsSpatiallyPartitioned(op)) { + ops_sharding[op_num] = &op->sharding(); + } + } + if (ops_sharding[0] == nullptr && ops_sharding[1] == nullptr) { + return false; + } + + // Select representative operand. + int64 representative_op = -1; + if (ops_sharding[0] == nullptr) { + representative_op = 1; + } else if (ops_sharding[1] == nullptr) { + representative_op = 0; + } else if (ops_sharding[0]->IsReplicated() && + ops_sharding[1]->IsReplicated()) { + // Both replicated -> replicate + return MaybeImproveInstructionSharding(HloSharding::Replicate(), + instruction); + } else if (!ops_sharding[0]->IsReplicated() && + !ops_sharding[1]->IsReplicated()) { + // Both tile sharded. The dot spatial partitioning implementation + // replicates the operand corresponding to the non-tiled dimension: + // dot(lhs, rhs), sharding={devices=[1, ..., n, 1]} replicates rhs + // dot(lhs, rhs), sharding={devices=[1, ..., 1, n]} replicates lhs + // so set sharding in order to replicate the smaller of lhs and rhs + representative_op = + ShapeUtil::ByteSizeOf(instruction->operand(0)->shape()) < + ShapeUtil::ByteSizeOf(instruction->operand(1)->shape()) + ? 1 + : 0; + } else { + // One is replicated and the other is tiled - pick the tiled one. + representative_op = ops_sharding[0]->IsReplicated() ? 1 : 0; + } + + if (ops_sharding[representative_op]->IsReplicated()) { + return MaybeImproveInstructionSharding(HloSharding::Replicate(), + instruction); + } else { + // Tile-shard instruction according to representative op. 
+ auto sharding = *ops_sharding[representative_op]; + if (instruction->shape().dimensions_size() != + sharding.tile_assignment().num_dimensions()) { + // It is necessarily the case of a matrix x vector, with + // representative_op being the matrix, because the vector op has the + // same shape as instruction. + CHECK_EQ(sharding.tile_assignment().num_dimensions(), + instruction->shape().dimensions_size() + 1); + // Reshape sharding so that last dimension is 1, and then remove + // last dimension. + std::vector non_batch_dims( + sharding.tile_assignment().num_dimensions() - num_batch_dims); + absl::c_iota(non_batch_dims, num_batch_dims); + sharding = hlo_sharding_util::ReshapeToTileDimension( + sharding, num_batch_dims, non_batch_dims); + auto tile_assignment = sharding.tile_assignment(); + auto dimensions = tile_assignment.dimensions(); + CHECK_EQ(dimensions.back(), 1); + dimensions.pop_back(); + tile_assignment.Reshape(dimensions); + sharding = HloSharding::Tile(tile_assignment); + } + return MaybeImproveInstructionSharding(sharding, instruction); + } + } + case HloOpcode::kParameter: { + auto parent_it = computation_map.find(instruction->parent()); + if (parent_it == computation_map.end()) { + return false; + } + const HloInstruction* parent = parent_it->second; + switch (parent->opcode()) { + case HloOpcode::kConditional: { + for (int64 i = 1; i < parent->operand_count(); ++i) { + if (parent->called_computations()[i - 1] == instruction->parent()) { + if (parent->operand(i)->has_sharding()) { + return MaybeImproveInstructionSharding( + parent->operand(i)->sharding(), instruction); + } + return false; + } + } + return false; + } + default: + return false; + } + } + case HloOpcode::kSort: { + const HloInstruction* operand = PickRepresentativeOperand(instruction); + if (!operand || !IsSpatiallyPartitioned(operand)) { + return false; + } + + if (!operand->sharding().IsTileMaximal() && + operand->sharding().tile_assignment().dim( + instruction->dimensions(0)) != 1) { + // Doesn't support sharding the sorting dimension. + return false; + } + + if (instruction->shape().IsTuple()) { + return MaybeImproveInstructionSharding( + HloSharding::SingleTuple(instruction->shape(), operand->sharding()), + instruction); + } else { + return MaybeImproveInstructionSharding(operand->sharding(), + instruction); + } + } + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: { + auto propagate_slicing = [instruction]() { + const HloInstruction* operand = + instruction->opcode() == HloOpcode::kDynamicSlice + ? 
instruction->operand(0) + : instruction->operand(1); + if (!IsSpatiallyPartitioned(operand)) { + return false; + } + + if (operand->sharding().IsReplicated()) { + return MaybeImproveInstructionSharding(HloSharding::Replicate(), + instruction); + } + + const auto& tile_assignment = operand->sharding().tile_assignment(); + for (int64 i = 0; i < instruction->shape().rank(); ++i) { + if (tile_assignment.dim(i) > 1 && + instruction->shape().dimensions(i) != + operand->shape().dimensions(i)) { + return false; + } + } + return MaybeImproveInstructionSharding(operand->sharding(), + instruction); + }; + auto propagate_base = [instruction]() { + if (instruction->opcode() != HloOpcode::kDynamicUpdateSlice) { + return false; + } + if (!IsSpatiallyPartitioned(instruction->operand(0))) { + return false; + } + return MaybeImproveInstructionSharding( + instruction->operand(0)->sharding(), instruction); + }; + return propagate_slicing() || propagate_base(); + } + case HloOpcode::kGather: { + if (!IsSpatiallyPartitioned(instruction->operand(1))) { + return false; + } + HloSharding new_sharding = hlo_sharding_util::GatherOutputSharding( + instruction->operand(1)->sharding(), instruction); + return MaybeImproveInstructionSharding(new_sharding, instruction); + } + case HloOpcode::kScatter: { + if (!IsSpatiallyPartitioned(instruction->operand(1)) && + !IsSpatiallyPartitioned(instruction->operand(2))) { + return false; + } + return MaybeImproveInstructionSharding(HloSharding::Replicate(), + instruction); + } + case HloOpcode::kWhile: { + if (!instruction->operand(0)->has_sharding()) { + return false; + } + auto sharding = instruction->operand(0)->sharding(); + if (instruction->has_sharding()) { + sharding = + MergeForMoreSpecificSharding(sharding, instruction->sharding()); + } + return MaybeImproveInstructionSharding(sharding, instruction); + } + default: { + const HloInstruction* operand = PickRepresentativeOperand(instruction); + if (!operand || !IsSpatiallyPartitioned(operand)) { + return false; + } + return MaybeImproveInstructionSharding(operand->sharding(), instruction); + } + } + return false; +} + +// Return the sharding that should be propagated from user to instruction. +absl::optional GetShardingFromUser( + const HloInstruction& instruction, const HloInstruction& user, + bool aggressive_prop, bool is_spmd) { + if (!IsSpatiallyPartitioned(&user)) { + return absl::nullopt; + } + switch (user.opcode()) { + case HloOpcode::kBroadcast: { + if (user.sharding().IsReplicated()) { + return user.sharding(); + } + // Only support when none of the partitioned dimensions in the broadcast + // output belong to new dimensions. + for (int64 i = 0; i < user.shape().rank(); ++i) { + if (user.sharding().tile_assignment().dim(i) > 1 && + absl::c_count(user.dimensions(), i) == 0) { + return absl::nullopt; + } + } + + // The instruction (operand of broadcast) will be tiled the same way + // as the output. 
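+      // E.g. if the broadcast user of an f32[8] operand produces f32[8,16]
+      // with dimensions={0} and is sharded as {devices=[2,1]0,1}, the operand
+      // receives {devices=[2]0,1} (an illustrative two-device assignment).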
+ std::vector target_tile_assignment_dimensions; + for (int64 output_dim : user.dimensions()) { + target_tile_assignment_dimensions.push_back( + user.sharding().tile_assignment().dim(output_dim)); + } + Array new_tile_assignment = user.sharding().tile_assignment(); + new_tile_assignment.Reshape(target_tile_assignment_dimensions); + return HloSharding::Tile(new_tile_assignment); + } + case HloOpcode::kConcatenate: { + if (user.sharding().IsReplicated()) { + return user.sharding(); + } + + const int64 cdim = user.concatenate_dimension(); + const Array& tile_assignment = user.sharding().tile_assignment(); + if (tile_assignment.dim(cdim) == 1) { + // If we are concatenating along a non-sharded dimension then the + // operands should have the same sharding as the result. + return user.sharding(); + } + + if (is_spmd) { + // SPMD doesn't support tiling with part of the devices. Return the same + // sharding. + return user.sharding(); + } + + // If we are concatenating along a sharded dimension then we want the + // operands to be distributed among the devices their data is used. + int64 start_offset = 0; + for (HloInstruction* op : user.operands()) { + if (op == &instruction) { + break; + } + start_offset += op->shape().dimensions(cdim); + } + const int64 tile_shape = CeilOfRatio(user.shape().dimensions(cdim), + tile_assignment.dimensions()[cdim]); + std::vector start_indices(tile_assignment.num_dimensions()); + std::vector end_indices = tile_assignment.dimensions(); + start_indices[cdim] = start_offset / tile_shape; + end_indices[cdim] = CeilOfRatio( + start_offset + instruction.shape().dimensions(cdim), tile_shape); + auto new_tile_assignment = + tile_assignment.Slice(start_indices, end_indices); + if (new_tile_assignment.num_elements() == 1) { + return HloSharding::AssignDevice(*new_tile_assignment.begin()); + } + return HloSharding::Tile(new_tile_assignment); + } + case HloOpcode::kConvolution: { + if (auto dot_dims = + dot_as_convolution_util::ParseDotGeneralFromConvolution(&user)) { + const auto& dnums = user.convolution_dimension_numbers(); + auto partitioned_only_along = + [&](const HloSharding& sharding, + std::vector& + dims) { + if (sharding.IsTileMaximal()) { + return false; + } + int64 partition_count = 1; + for (const auto& dim : dims) { + partition_count *= sharding.tile_assignment().dim(dim.output); + } + return partition_count == + sharding.tile_assignment().num_elements(); + }; + // If output is partitioned only along the batch dimensions, or only + // along the non-contracting dimensions, propagate the sharding to the + // operand. 
+ if (&instruction == user.operand(0) && + (partitioned_only_along(user.sharding(), dot_dims->batch_dims) || + partitioned_only_along(user.sharding(), + dot_dims->lhs_non_contracting_dims))) { + std::vector lhs_to_output_indices(user.shape().rank()); + lhs_to_output_indices[dnums.input_batch_dimension()] = + dnums.output_batch_dimension(); + lhs_to_output_indices[dnums.input_feature_dimension()] = + dnums.output_feature_dimension(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + lhs_to_output_indices[dnums.input_spatial_dimensions(i)] = + dnums.output_spatial_dimensions(i); + } + return hlo_sharding_util::TransposeSharding(user.sharding(), + lhs_to_output_indices); + } + if (&instruction == user.operand(1) && + (partitioned_only_along(user.sharding(), dot_dims->batch_dims) || + partitioned_only_along(user.sharding(), + dot_dims->rhs_non_contracting_dims))) { + std::vector rhs_to_output_indices(user.shape().rank()); + rhs_to_output_indices[dnums.kernel_input_feature_dimension()] = + dnums.output_batch_dimension(); + rhs_to_output_indices[dnums.kernel_output_feature_dimension()] = + dnums.output_feature_dimension(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + rhs_to_output_indices[dnums.kernel_spatial_dimensions(i)] = + dnums.output_spatial_dimensions(i); + } + return hlo_sharding_util::TransposeSharding(user.sharding(), + rhs_to_output_indices); + } + } + return absl::nullopt; + } + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: { + if (user.sharding().IsReplicated()) { + return user.sharding(); + } + if (user.opcode() == HloOpcode::kDynamicUpdateSlice && + &instruction == user.operand(0)) { + return user.sharding(); + } + const HloInstruction* operand = user.opcode() == HloOpcode::kDynamicSlice + ? user.operand(0) + : user.operand(1); + if (&instruction != operand) { + return absl::nullopt; + } + + const auto& tile_assignment = user.sharding().tile_assignment(); + for (int64 i = 0; i < user.shape().rank(); ++i) { + if (tile_assignment.dim(i) > 1 && + user.shape().dimensions(i) != operand->shape().dimensions(i)) { + return absl::nullopt; + } + } + return user.sharding(); + } + case HloOpcode::kReduceWindow: { + if (&instruction != user.operand(0)) { + return absl::nullopt; + } + return user.sharding(); + } + case HloOpcode::kReshape: { + return hlo_sharding_util::ReshapeSharding( + user.shape(), instruction.shape(), user.sharding()); + } + case HloOpcode::kTranspose: { + // Calculate the dimension numbers for reversing the current transpose + // and then use TransposeSharding to convert the output sharding to an + // input sharding. + std::vector reverse_dimensions(user.dimensions().size()); + for (int64 i = 0; i < user.dimensions().size(); ++i) { + reverse_dimensions[user.dimensions(i)] = i; + } + return hlo_sharding_util::TransposeSharding(user.sharding(), + reverse_dimensions); + } + case HloOpcode::kTuple: { + return user.sharding().GetSubSharding(user.shape(), + {user.operand_index(&instruction)}); + } + case HloOpcode::kGetTupleElement: { + HloSharding new_sharding = + instruction.has_sharding() + ? 
instruction.sharding() + : HloSharding::SingleTuple(instruction.shape(), + HloSharding::Replicate()); + int64 sharding_index = 0; + for (int64 i = 0; i < instruction.shape().tuple_shapes_size(); ++i) { + if (i == user.tuple_index()) { + break; + } + if (instruction.shape().tuple_shapes(i).IsArray()) { + sharding_index += 1; + } else { + sharding_index += + instruction.shape().tuple_shapes(i).tuple_shapes_size(); + } + } + if (user.shape().IsArray()) { + new_sharding.tuple_elements()[sharding_index] = user.sharding(); + } + for (int64 i = 0; i < user.sharding().tuple_elements().size(); ++i) { + new_sharding.tuple_elements()[sharding_index + i] = + user.sharding().tuple_elements()[i]; + } + return new_sharding; + } + case HloOpcode::kDot: { + if (user.sharding().IsReplicated()) { + return user.sharding(); + } + auto& dim_numbers = user.dot_dimension_numbers(); + int64 op_idx = user.operand_index(&instruction); + // Batch dimensions are the same on lhs and rhs for dot operations. + int64 num_batch_dims = dim_numbers.lhs_batch_dimensions_size(); + int64 num_spatial_dims = + instruction.shape().dimensions_size() - num_batch_dims; + if (num_spatial_dims == 1) { + // This is the vector of a matrix x vector operation -> replicate, + // since tiling on the vector would necessarily be on the contracting + // dimension, which we don't support. + CHECK_EQ(op_idx, 1); + return HloSharding::Replicate(); + } + // Instruction is necessarily a matrix because it is one of the operands + // of a matrix x matrix operation. + CHECK_EQ(num_spatial_dims, 2); + // Propagate tile sharding to the bigger operand, and replicate the other. + auto other_op = user.operand(op_idx ^ 1); + if (ShapeUtil::ByteSizeOf(instruction.shape()) > + ShapeUtil::ByteSizeOf(other_op->shape())) { + return user.sharding(); + } else { + return HloSharding::Replicate(); + } + } + case HloOpcode::kReduce: { + if (instruction.shape().rank() == 0) { + return absl::nullopt; + } + auto user_sharding = + user.shape().IsTuple() + ? user.sharding().GetSubSharding( + user.shape(), {user.operand_index(&instruction)}) + : user.sharding(); + if (user_sharding.IsTileMaximal()) { + return user_sharding; + } + std::vector target_tile_assignment_dimensions( + instruction.shape().rank()); + const auto& dimensions = user.dimensions(); + int64 next_output_dim = 0; + for (int64 i = 0; i < instruction.shape().rank(); ++i) { + if (absl::c_find(dimensions, i) == dimensions.end()) { + target_tile_assignment_dimensions[i] = + user_sharding.tile_assignment().dim(next_output_dim++); + } else { + target_tile_assignment_dimensions[i] = 1; + } + } + auto tile_assignment = user_sharding.tile_assignment(); + tile_assignment.Reshape(target_tile_assignment_dimensions); + return HloSharding::Tile(tile_assignment); + } + case HloOpcode::kSort: { + if (user.sharding().IsTuple()) { + return user.sharding().GetSubSharding( + user.shape(), {user.operand_index(&instruction)}); + } else { + return user.sharding(); + } + } + case HloOpcode::kReverse: { + return hlo_sharding_util::ReverseSharding(user.sharding(), + user.dimensions()); + } + default: { + // If the user output shape is compatible with the current instruction + // shape excluding element type and the current instruction is supported + // by spatial partitioning, then the user sharding can be used for + // propagation to the current instruction. 
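+      // A convert is a typical example: for %c = bf16[8,16] convert(%x) where
+      // %x is f32[8,16], the shapes match ignoring element type, so %c's
+      // sharding can be proposed for %x (the names are illustrative).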
+      if (ShapeUtil::CompatibleIgnoringElementType(instruction.shape(),
+                                                   user.shape())) {
+        return user.sharding();
+      }
+      return absl::nullopt;
+    }
+  }
+}
+
+// Tries to update the sharding of the specified instruction based on its users
+// and returns true if the sharding of the instruction has been changed and
+// false otherwise.
+bool InferShardingFromUsers(HloInstruction* instruction,
+                            const ComputationMap& computation_map,
+                            bool aggressive_prop, bool is_spmd) {
+  if (!SupportSpatialPartitioning(instruction, computation_map, is_spmd)) {
+    return false;
+  }
+  bool improved_sharding = false;
+  for (const HloInstruction* user : instruction->users()) {
+    absl::optional<HloSharding> user_sharding =
+        GetShardingFromUser(*instruction, *user, aggressive_prop, is_spmd);
+    if (user_sharding) {
+      improved_sharding |=
+          MaybeImproveInstructionSharding(*user_sharding, instruction);
+    }
+  }
+  return improved_sharding;
+}
+
+// Removes a Sharding custom-call instruction by folding the sharding attribute
+// into its operand. If the operand already has a different sharding, inserts a
+// copy node to reshard.
+StatusOr<bool> ProcessShardingInstruction(HloModule* module) {
+  bool changed = false;
+
+  for (HloComputation* computation : module->computations()) {
+    auto instructions = computation->MakeInstructionPostOrder();
+    std::reverse(instructions.begin(), instructions.end());
+    for (HloInstruction* instruction : instructions) {
+      if (instruction->opcode() != HloOpcode::kCustomCall) {
+        continue;
+      }
+      if (instruction->custom_call_target() != "Sharding") {
+        continue;
+      }
+      TF_RET_CHECK(instruction->has_sharding())
+          << "Sharding instruction must have a sharding attribute";
+      const HloSharding& sharding = instruction->sharding();
+
+      // If the operand has a different sharding from the current sharding
+      // instruction, create a copy node. Otherwise, just remove the sharding
+      // instruction and set the operand sharding.
+      if (instruction->operand(0)->has_sharding() &&
+          instruction->operand(0)->sharding() != sharding) {
+        auto copy = computation->AddInstruction(
+            HloInstruction::CreateUnary(instruction->shape(), HloOpcode::kCopy,
+                                        instruction->mutable_operand(0)));
+        TF_RETURN_IF_ERROR(computation->ReplaceInstruction(instruction, copy));
+        copy->set_sharding(sharding);
+      } else {
+        instruction->mutable_operand(0)->set_sharding(sharding);
+        TF_RETURN_IF_ERROR(
+            instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
+        TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
+      }
+      changed = true;
+    }
+  }
+  return changed;
+}
+
+}  // namespace
+
+/*static*/ Status ShardingPropagation::NormalizeDomain(
+    const DomainMetadata::Domain& domain, const DomainMetadata* metadata) {
+  if (metadata != nullptr) {
+    TF_ASSIGN_OR_RETURN(const auto& sharding_metadata,
+                        ShardingMetadata::ToShardingMetadata(metadata));
+    const auto& sharding = sharding_metadata->sharding();
+    if (sharding != nullptr) {
+      bool is_spatially_partitioned = !sharding->HasUniqueDevice();
+      if (sharding->IsTuple()) {
+        is_spatially_partitioned = absl::c_any_of(
+            sharding->tuple_elements(),
+            [](const HloSharding& s) { return !s.HasUniqueDevice(); });
+      }
+      if (is_spatially_partitioned) {
+        for (HloInstruction* domain : domain.exit_domains) {
+          domain->mutable_operand(0)->set_sharding(*sharding);
+        }
+        return Status::OK();
+      }
+    }
+  }
+  return ShardingMetadata::NormalizeShardingDomain(domain, metadata);
+}
+
+StatusOr<bool> ShardingPropagation::Run(HloModule* module) {
+  TF_ASSIGN_OR_RETURN(bool any_changed, ProcessShardingInstruction(module));
+
+  // Association of partitionable embedded computations with their parent
+  // instruction.
+  ComputationMap computation_map;
+
+  // Instructions that are related through a computation and need to share the
+  // same sharding.
+  auto get_related_instructions = [](HloInstruction* inst) {
+    if (inst->opcode() == HloOpcode::kWhile) {
+      return std::vector<HloInstruction*>{
+          inst, inst->while_body()->root_instruction(),
+          inst->while_body()->parameter_instruction(0),
+          inst->while_condition()->parameter_instruction(0)};
+    } else if (inst->opcode() == HloOpcode::kConditional) {
+      std::vector<HloInstruction*> comps{inst};
+      for (HloComputation* c : inst->called_computations()) {
+        comps.push_back(c->root_instruction());
+      }
+      return comps;
+    } else {
+      CHECK(false);
+    }
+  };
+
+  // If the instruction is a while, or the root or a parameter of a while body,
+  // then propagate its sharding to the while instruction, to its body root,
+  // and to its condition parameter.
+  std::function<void(HloInstruction*)> maybe_computation_propagation =
+      [&](HloInstruction* instruction) {
+        auto propagate_to_instruction = [&](HloInstruction* search_inst) {
+          auto related_instructions = get_related_instructions(search_inst);
+          if (absl::c_count(related_instructions, instruction)) {
+            for (HloInstruction* inst : related_instructions) {
+              if (!inst->has_sharding() ||
+                  inst->sharding() != instruction->sharding()) {
+                VLOG(2) << "Add computation sharding: " << inst->name();
+                inst->set_sharding(instruction->sharding());
+                maybe_computation_propagation(inst);
+              }
+            }
+          }
+        };
+
+        if (instruction->opcode() == HloOpcode::kConditional ||
+            instruction->opcode() == HloOpcode::kWhile) {
+          propagate_to_instruction(instruction);
+        }
+
+        if (instruction->opcode() == HloOpcode::kParameter ||
+            instruction->parent()->root_instruction() == instruction) {
+          auto it = computation_map.find(instruction->parent());
+          if (it != computation_map.end()) {
+            propagate_to_instruction(it->second);
+          }
+        }
+      };
+
+  // Populate computation_map in order to associate while bodies to their
+  // while instructions.
+  for (auto computation : module->computations()) {
+    for (auto instruction : computation->instructions()) {
+      if (instruction->opcode() == HloOpcode::kWhile ||
+          instruction->opcode() == HloOpcode::kConditional) {
+        // Check if any of the related instructions has sharding, in which case
+        // propagate it to the other instructions, so they all share the same
+        // sharding, in case the user didn't shard all of them. We don't check
+        // that user shardings are consistent, because such a check is already
+        // done by HloShardingVerifier.
+        const HloInstruction* sharded_inst = nullptr;
+        auto related_instructions = get_related_instructions(instruction);
+        for (auto inst : related_instructions) {
+          if (inst->has_sharding()) {
+            sharded_inst = inst;
+            break;
+          }
+        }
+        if (sharded_inst != nullptr) {
+          // Set the same sharding to all the other related instructions.
+          for (auto inst : related_instructions) {
+            inst->set_sharding(sharded_inst->sharding());
+          }
+        }
+      }
+      if (instruction->opcode() == HloOpcode::kWhile) {
+        computation_map[instruction->while_body()] = instruction;
+      } else if (instruction->opcode() == HloOpcode::kConditional) {
+        for (HloComputation* c : instruction->called_computations()) {
+          computation_map[c] = instruction;
+        }
+      }
+    }
+  }
+
+  // Collect all pre-sharded instructions as we aren't allowed to modify their
+  // sharding.
+  absl::flat_hash_set<const HloInstruction*> provided_shardings;
+  for (const HloComputation* computation : module->computations()) {
+    for (const HloInstruction* inst : computation->instructions()) {
+      if (inst->has_sharding()) {
+        provided_shardings.insert(inst);
+      }
+    }
+  }
+
+  // Consider the root instruction of the entry computation as one with provided
+  // sharding, as its sharding has to match the one expected by the host.
+  provided_shardings.insert(module->entry_computation()->root_instruction());
+
+  // Iterate to a fixpoint that is guaranteed to be reached because we only
+  // strictly improve the sharding of the graph and it can't be improved
+  // indefinitely.
+ int64 iterations = 0; + auto run_to_fix_point = [&](bool aggressive_prop) { + bool changed = true; + while (changed) { + changed = false; + int64 inferred_from_operand_counter = 0; + int64 inferred_from_user_counter = 0; + int64 instruction_counter = 0; + int64 already_sharded_counter = 0; + for (const HloComputation* computation : module->computations()) { + std::vector instructions = + computation->MakeInstructionPostOrder(); + + instruction_counter += instructions.size(); + for (const HloInstruction* instruction : instructions) { + already_sharded_counter += (instruction->has_sharding() ? 1 : 0); + } + + // Remove the instructions where the sharding was provided from the + // outside so we don't modify them. + instructions.erase( + std::remove_if(instructions.begin(), instructions.end(), + [&](HloInstruction* instruction) { + return provided_shardings.contains(instruction); + }), + instructions.end()); + + // First iterate the HLO graph in post order taking shardings from + // operands. + for (HloInstruction* instruction : instructions) { + if (InferShardingFromOperands(instruction, computation_map, is_spmd_, + aggressive_prop)) { + ++inferred_from_operand_counter; + changed = true; + VLOG(2) << "Add sharding (forward-pass): " + << instruction->ToString(); + maybe_computation_propagation(instruction); + } + } + + // Then iterate the HLO graph in reverse post order taking shardings + // from users. + for (auto it = instructions.rbegin(); it != instructions.rend(); ++it) { + if (InferShardingFromUsers(*it, computation_map, aggressive_prop, + is_spmd_)) { + ++inferred_from_user_counter; + changed = true; + VLOG(2) << "Add sharding (backward-pass): " << (*it)->ToString(); + maybe_computation_propagation(*it); + } + } + } + any_changed |= changed; + VLOG(1) << "Sharding propagation iteration " << iterations << ";"; + VLOG(1) << " total instructions: " << instruction_counter; + VLOG(1) << " instructions already sharded: " << already_sharded_counter; + VLOG(1) << " shardings inferred from operands: " + << inferred_from_operand_counter; + VLOG(1) << " shardings inferred from users: " + << inferred_from_user_counter; + ++iterations; + } + }; + run_to_fix_point(false); + run_to_fix_point(true); + + VLOG(1) << "Sharding propagation completed after " << iterations + << " iterations"; + return any_changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/sharding_propagation.h b/tensorflow/compiler/xla/service/sharding_propagation.h new file mode 100644 index 00000000000..2c07a4a6a31 --- /dev/null +++ b/tensorflow/compiler/xla/service/sharding_propagation.h @@ -0,0 +1,50 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SHARDING_PROPAGATION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_SHARDING_PROPAGATION_H_
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+
+// Propagates sharding information around the graph. HLOs that have shardings
+// are kept as-is; those that do not are assigned shardings based on a simple
+// local greedy heuristic.
+class ShardingPropagation : public HloModulePass {
+ public:
+  explicit ShardingPropagation(bool is_spmd = false) : is_spmd_(is_spmd) {}
+  absl::string_view name() const override { return "sharding-propagation"; }
+  StatusOr<bool> Run(HloModule* module) override;
+
+  // Function which can be used to apply a spatially partitioned sharding onto
+  // a given domain. It will apply the sharding to the exit edges of the domain
+  // and then rely on the rest of sharding propagation to ensure that the
+  // intermediate nodes get the correct sharding.
+  static Status NormalizeDomain(const DomainMetadata::Domain& domain,
+                                const DomainMetadata* metadata);
+
+ private:
+  bool is_spmd_;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SHARDING_PROPAGATION_H_
diff --git a/tensorflow/compiler/xla/service/sharding_propagation_test.cc b/tensorflow/compiler/xla/service/sharding_propagation_test.cc
new file mode 100644
index 00000000000..a9d685a7a93
--- /dev/null
+++ b/tensorflow/compiler/xla/service/sharding_propagation_test.cc
@@ -0,0 +1,1329 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/sharding_propagation.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace { + +using ShardingPropagationTest = HloTestBase; + +TEST_F(ShardingPropagationTest, ElementwiseOperationForwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %elementwise { + %param0 = f32[5,7,11,13]{3,2,1,0} parameter(0), + sharding={devices=[1,2,2,1]0,1,2,3} + %param1 = f32[5,7,11,13]{3,2,1,0} parameter(1) + %add = f32[5,7,11,13]{3,2,1,0} add(%param0, %param1) + ROOT %copy = f32[5,7,11,13]{3,2,1,0} copy(%add) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "add"), + op::Sharding("{devices=[1,2,2,1]0,1,2,3}")); +} + +TEST_F(ShardingPropagationTest, ElementwiseOperationBackwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %elementwise { + %param0 = f32[5,7,11,13]{3,2,1,0} parameter(0) + %param1 = f32[5,7,11,13]{3,2,1,0} parameter(1) + %add = f32[5,7,11,13]{3,2,1,0} add(%param0, %param1) + ROOT %copy = f32[5,7,11,13]{3,2,1,0} copy(%add), + sharding={devices=[1,2,2,1]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "add"), + op::Sharding("{devices=[1,2,2,1]0,1,2,3}")); +} + +TEST_F(ShardingPropagationTest, BroadcastForwardPassNoSharding) { + const char* const hlo_string = R"( +HloModule module +ENTRY %broadcast { + %param0 = f32[7,11]{1,0} parameter(0), + sharding={devices=[2,2]0,1,2,3} + %broadcast = f32[5,7,11,13]{3,2,1,0} broadcast(%param0), dimensions={1,2} + ROOT %copy = f32[5,7,11,13]{3,2,1,0} copy(%broadcast) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_FALSE(changed); +} + +// Regression Test for b/129569657. 
+TEST_F(ShardingPropagationTest, BroadcastForwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %broadcast { + %param0 = f32[3,2048,2048]{2,1,0} parameter(0), + sharding={devices=[1,2,2]0,1,2,3} + %broadcast = f32[3,2048,2048,3]{3,2,1,0} broadcast(%param0), dimensions={0,1,2} + ROOT %copy = f32[3,2048,2048,3]{3,2,1,0} copy(%broadcast) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "broadcast"), + op::Sharding("{devices=[1,2,2,1]0,1,2,3}")); +} + +TEST_F(ShardingPropagationTest, BroadcastBackwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %broadcast { + %param0 = f32[13]{0} parameter(0) + %broadcast = f32[5,7,11,13]{3,2,1,0} broadcast(%param0), dimensions={3} + ROOT %copy = f32[5,7,11,13]{3,2,1,0} copy(%broadcast), + sharding={devices=[1,2,2,1]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "broadcast"), + op::Sharding("{devices=[1,2,2,1]0,1,2,3}")); +} + +TEST_F(ShardingPropagationTest, BroadcastUser) { + const char* const hlo_string = R"( +HloModule module +ENTRY %broadcast { + %param0 = f32[24,8]{0,1} parameter(0) + %copy = f32[24,8]{0,1} copy(%param0) + ROOT %broadcast = f32[4,24,6,8]{3,2,1,0} broadcast(%copy), dimensions={1,3}, + sharding={devices=[1,2,1,4]0,1,2,3,4,5,6,7} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "copy"), + op::Sharding("{devices=[2,4]0,1,2,3,4,5,6,7}")); +} + +TEST_F(ShardingPropagationTest, MaximalReduceForwardPass) { + const char* const hlo_string = R"( +HloModule module +%add { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(%lhs, %rhs) +} +ENTRY %reduce { + %param0 = f32[5,7,11,13]{3,2,1,0} parameter(0), + sharding={devices=[1,2,2,1]0,1,2,3} + %init = f32[] parameter(1) + %reduce = f32[5,7]{1,0} reduce(%param0, %init), dimensions={2,3}, to_apply=%add + ROOT %copy = f32[5,7]{0,1} copy(%reduce) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "reduce"), + op::Sharding("{replicated}")); +} + +TEST_F(ShardingPropagationTest, ShardedReduceForwardPass) { + const char* const hlo_string = R"( +HloModule module +%add { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(%lhs, %rhs) +} +ENTRY %reduce { + %param0 = f32[5,7,11,13]{3,2,1,0} parameter(0), + sharding={devices=[1,2,2,1]0,1,2,3} + %init = f32[] parameter(1) + %reduce = f32[7,11]{1,0} reduce(%param0, %init), dimensions={0,3}, to_apply=%add + ROOT %copy = f32[7,11]{0,1} copy(f32[7,11]{1,0} %reduce) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "reduce"), + op::Sharding("{devices=[2,2]0,1,2,3}")); +} + +TEST_F(ShardingPropagationTest, 
ShardedTupleReduceForwardAndBackwardPass) { + const char* const hlo_string = R"( +HloModule module + +%minmax_func { + %lhs_value = f32[] parameter(0) + %rhs_value = f32[] parameter(2) + %compare.2 = pred[] compare(%lhs_value, %rhs_value), direction=GT + %select.4 = f32[] select(%compare.2, %lhs_value, %rhs_value) + %lhs_index = s32[] parameter(1) + %rhs_index = s32[] parameter(3) + %select.5 = s32[] select(%compare.2, %lhs_index, %rhs_index) + ROOT %tuple.2 = (f32[], s32[]) tuple(%select.4, %select.5) +} + +ENTRY %main { + %param0 = f32[28,10] parameter(0) + %param1 = s32[28,10] parameter(1), sharding={devices=[2,1]0,1} + %copy_param0 = f32[28,10] copy(%param0) + %init0 = f32[] parameter(2) + %init1 = s32[] parameter(3) + %reduce = (f32[28], s32[28]) reduce(%copy_param0, %param1, %init0, %init1), + dimensions={1}, to_apply=%minmax_func + %gte0 = f32[28] get-tuple-element(%reduce), index=0 + %gte1 = s32[28] get-tuple-element(%reduce), index=1 + %copy0 = f32[28] copy(%gte0) + %copy1 = s32[28] copy(%gte1) + ROOT %tuple = (f32[28], s32[28]) tuple(%copy0, %copy1) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "reduce"), + op::Sharding("{{devices=[2]0,1},{devices=[2]0,1}}")); + EXPECT_THAT(FindInstruction(module.get(), "copy_param0"), + op::Sharding("{devices=[2,1]0,1}")); +} + +TEST_F(ShardingPropagationTest, GetTupleElementForwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %gte { + %param0 = f32[5,7,11,13]{3,2,1,0} parameter(0) + %tuple = (f32[5,7,11,13]{3,2,1,0}, f32[5,7,11,13]{3,2,1,0}) tuple( + %param0, %param0) + %tuple.1 = (f32[5,7,11,13]{3,2,1,0}, + (f32[5,7,11,13]{3,2,1,0}, f32[5,7,11,13]{3,2,1,0})) tuple( + %param0, %tuple), + sharding={{devices=[1,2,2,1]0,1,2,3}, + {replicated}, + {devices=[1,2,2,1]0,1,2,3}} + %gte = f32[5,7,11,13]{3,2,1,0} get-tuple-element(%tuple.1), index=0 + %gte.1 = (f32[5,7,11,13]{3,2,1,0}, f32[5,7,11,13]{3,2,1,0}) get-tuple-element( + %tuple.1), index=1 + %gte.2 = f32[5,7,11,13]{3,2,1,0} get-tuple-element(%gte.1), index=0 + ROOT %copy = f32[5,7,11,13]{3,2,1,0} copy(%gte.2) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "gte"), + op::Sharding("{devices=[1,2,2,1]0,1,2,3}")); + EXPECT_THAT(FindInstruction(module.get(), "gte.1"), + op::Sharding("{{replicated}," + " {devices=[1,2,2,1]0,1,2,3}}")); + EXPECT_THAT(FindInstruction(module.get(), "gte.2"), + op::Sharding("{replicated}")); +} + +TEST_F(ShardingPropagationTest, TupleForwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %tuple { + %param0 = f32[5,7,11,13]{3,2,1,0} parameter(0), + sharding={replicated} + %param1 = f32[5,7,11,13]{3,2,1,0} parameter(1), + sharding={devices=[1,2,2,1]0,1,2,3} + %param2 = f32[5,7,11,13]{3,2,1,0} parameter(2) + %tuple = (f32[5,7,11,13]{3,2,1,0}, f32[5,7,11,13]{3,2,1,0}) tuple( + %param1, %param2) + %tuple.1 = (f32[5,7,11,13]{3,2,1,0}, + (f32[5,7,11,13]{3,2,1,0}, f32[5,7,11,13]{3,2,1,0})) tuple( + %param0, %tuple) + ROOT %copy = (f32[5,7,11,13]{3,2,1,0}, + (f32[5,7,11,13]{3,2,1,0}, f32[5,7,11,13]{3,2,1,0})) copy( + %tuple.1) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + 
TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "tuple"), + op::Sharding("{{devices=[1,2,2,1]0,1,2,3}," + " {replicated}}")); + EXPECT_THAT(FindInstruction(module.get(), "tuple.1"), + op::Sharding("{{replicated}," + " {devices=[1,2,2,1]0,1,2,3}," + " {replicated}}")); +} + +TEST_F(ShardingPropagationTest, ForwardConvolutionForwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %conv { + %lhs = f32[5,7,11,13]{3,2,1,0} parameter(0), + sharding={devices=[2,2,2,1]0,1,2,3,4,5,6,7} + %rhs = f32[3,3,13,17]{3,2,1,0} parameter(1) + %convolution = f32[5,7,11,17]{3,2,1,0} convolution(%lhs, %rhs), + window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f + ROOT %copy = f32[5,7,11,17]{3,2,1,0} copy(%convolution) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "convolution"), + op::Sharding("{devices=[2,2,2,1]0,1,2,3,4,5,6,7}")); +} + +TEST_F(ShardingPropagationTest, ForwardConvolutionLargeDilationForwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %conv { + %lhs = f32[8,64,2]{2,1,0} parameter(0), + sharding={devices=[1,4,1]0,1,2,3} + %rhs = f32[3,2,2]{2,1,0} parameter(1) + %convolution = f32[8,32,2]{2,1,0} convolution(%lhs, %rhs), + window={size=3 rhs_dilate=16}, dim_labels=b0f_0io->b0f + ROOT %copy = f32[8,32,2]{2,1,0} copy(%convolution) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "convolution"), + op::Sharding("{devices=[1,4,1]0,1,2,3}")); +} + +TEST_F(ShardingPropagationTest, TransposeForwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %transpose { + %param = f32[7,11,13]{2,1,0} parameter(0), + sharding={devices=[2,1,2]0,1,2,3} + %transpose = f32[11,13,7]{2,1,0} transpose(%param), dimensions={1,2,0} + ROOT %copy = f32[11,13,7]{2,1,0} copy(%transpose) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "transpose"), + op::Sharding("{devices=[1,2,2]0,2,1,3}")); +} + +TEST_F(ShardingPropagationTest, TransposeBackwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %transpose { + %param = f32[7,11,13]{2,1,0} parameter(0) + %copy = f32[7,11,13]{2,1,0} copy(%param) + ROOT %transpose = f32[11,13,7]{2,1,0} transpose(%copy), dimensions={1,2,0}, + sharding={devices=[1,2,2]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "copy"), + op::Sharding("{devices=[2,1,2]0,2,1,3}")); +} + +TEST_F(ShardingPropagationTest, ReshapeForwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %reshape { + %param0 = f32[1430,1]{1,0} parameter(0), + sharding={devices=[2,1]0,1} + %reshape = f32[10,11,13]{2,1,0} reshape(%param0) + ROOT %copy = f32[10,11,13]{2,1,0} copy(%reshape) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + 
ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "reshape"), + op::Sharding("{devices=[2,1,1]0,1}")); +} + +TEST_F(ShardingPropagationTest, ReshapeBackwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %reshape { + %param0 = f32[2002,1]{1,0} parameter(0) + %copy = f32[2002,1]{1,0} copy(f32[2002,1]{1,0} %param0) + ROOT %reshape = f32[14,11,13]{2,1,0} reshape(%copy), + sharding={devices=[2,1,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "copy"), + op::Sharding("{devices=[2,1]0,1}")); +} + +TEST_F(ShardingPropagationTest, PadForwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %pad { + %input = f32[11,17]{1,0} parameter(0), + sharding={devices=[2,2]0,1,2,3} + %pad_value = f32[] parameter(1) + %pad = f32[27,51]{1,0} pad(%input, %pad_value), padding=2_4_1x1_1_2 + ROOT %copy = f32[27,51]{1,0} copy(%pad) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "pad"), + op::Sharding("{devices=[2,2]0,1,2,3}")); +} + +TEST_F(ShardingPropagationTest, ShardedPreferredOverReplicated) { + const char* const hlo_string = R"( +HloModule module +ENTRY %replicated { + %param0 = f32[5,7,11,13]{3,2,1,0} parameter(0), + sharding={replicated} + %copy = f32[5,7,11,13]{3,2,1,0} copy(%param0) + %param1 = f32[5,7,11,13]{3,2,1,0} parameter(1), + sharding={devices=[1,2,2,1]0,1,2,3} + %copy.1 = f32[5,7,11,13]{3,2,1,0} copy(%param1) + %add = f32[5,7,11,13]{3,2,1,0} add(%copy, %copy.1) + ROOT %copy.2 = f32[5,7,11,13]{3,2,1,0} copy(%add) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "copy"), + op::Sharding("{devices=[1,2,2,1]0,1,2,3}")); + EXPECT_THAT(FindInstruction(module.get(), "copy.1"), + op::Sharding("{devices=[1,2,2,1]0,1,2,3}")); + EXPECT_THAT(FindInstruction(module.get(), "add"), + op::Sharding("{devices=[1,2,2,1]0,1,2,3}")); +} + +TEST_F(ShardingPropagationTest, DontShardTuplesIfAllInputIsMaximal) { + const char* const hlo_string = R"( +HloModule module +ENTRY %tuple { + %param0 = f32[5,7,11,13]{3,2,1,0} parameter(0), + sharding={maximal device=0} + %param1 = f32[5,7,11,13]{3,2,1,0} parameter(1), + sharding={maximal device=1} + %tuple = (f32[5,7,11,13]{3,2,1,0}, f32[5,7,11,13]{3,2,1,0}) tuple( + %param0, %param1) + ROOT %copy = (f32[5,7,11,13]{3,2,1,0}, f32[5,7,11,13]{3,2,1,0}) copy(%tuple) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_FALSE(changed); + EXPECT_THAT(FindInstruction(module.get(), "tuple"), op::NoSharding()); +} + +TEST_F(ShardingPropagationTest, ValidConvolution) { + const char* const hlo_string = R"( +HloModule module + +ENTRY conv { + %lhs = f32[13,17,19]{2,1,0} parameter(0), + sharding={devices=[1,2,1]0,1} + %rhs = f32[19,5,19]{2,1,0} parameter(1) + %conv = f32[13,13,19]{2,1,0} convolution(%lhs, 
%rhs), + window={size=5}, dim_labels=b0f_i0o->b0f + ROOT %tuple = (f32[13,13,19]{2,1,0}) tuple(%conv) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "conv"), + op::Sharding("{devices=[1,2,1]0,1}")); +} + +TEST_F(ShardingPropagationTest, StridedSlice) { + const char* const hlo_string = R"( +HloModule module + +ENTRY %slice { + %param = f32[17,13]{1,0} parameter(0), + sharding={devices=[2,1]0,1} + %slice = f32[7,5]{1,0} slice(%param), slice={[1:15:2], [5:10:1]} + ROOT %tuple = (f32[7,5]{1,0}) tuple(%slice) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "slice"), + op::Sharding("{devices=[2,1]0,1}")); +} + +TEST_F(ShardingPropagationTest, ReduceWindowBackwardPass) { + const char* const hlo_string = R"( +HloModule module +%add (lhs: f32[], rhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(%lhs, %rhs) +} +ENTRY %reduce_window { + %param = f32[13,17]{1,0} parameter(0) + %param.copy = f32[13,17]{1,0} copy(%param) + %init = f32[] parameter(1) + ROOT %reduce-window = f32[7,17]{1,0} reduce-window(%param.copy, %init), + window={size=3x2 stride=2x1 pad=1_1x0_1}, to_apply=%add, + sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "param.copy"), + op::Sharding("{devices=[2,1]0,1}")); + EXPECT_THAT(FindInstruction(module.get(), "reduce-window"), + op::Sharding("{devices=[2,1]0,1}")); +} + +TEST_F(ShardingPropagationTest, ReplicatedConvolutionLhs) { + const char* const hlo_string = R"( +HloModule module + +ENTRY conv { + %lhs = f32[3,2,3]{2,1,0} parameter(0), sharding={replicated} + %rhs = f32[2,2,1]{2,1,0} parameter(1) + %conv = f32[3,2,3]{2,1,0} convolution(%lhs, %rhs), + window={size=1}, dim_labels=bf0_oi0->bf0 + ROOT %tuple = f32[3,2,3]{2,1,0} tuple(%conv) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "lhs"), + op::Sharding("{replicated}")); + EXPECT_THAT(FindInstruction(module.get(), "conv"), + op::Sharding("{replicated}")); +} + +TEST_F(ShardingPropagationTest, ConvolutionShardedFeature) { + const char* const hlo_string = R"( +HloModule module + +ENTRY conv { + %lhs = f32[3,2,3]{2,1,0} parameter(0), + sharding={devices=[1,2,1]0,1} + %rhs = f32[2,2,1]{2,1,0} parameter(1) + %conv = f32[3,2,3]{2,1,0} convolution(%lhs, %rhs), + window={size=1}, dim_labels=bf0_oi0->bf0 + ROOT %tuple = f32[3,2,3]{2,1,0} tuple(%conv) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(ShardingPropagationTest, ConvolutionDifferentDimensionNumbers) { + const char* const hlo_string = R"( +HloModule module + +ENTRY conv { + %lhs = f32[8,16,512] parameter(0), + sharding={devices=[1,2,1]0,1} + %rhs = f32[8,2,512] parameter(1) + 
%conv = f32[3,512,512] convolution(%lhs, %rhs), + window={size=2 stride=5}, + dim_labels=f0b_i0o->0bf + ROOT %tuple = f32[3,512,512] tuple(%conv) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "conv"), + op::Sharding("{devices=[2,1,1]0,1}")); +} + +TEST_F(ShardingPropagationTest, Concatenate) { + const char* const hlo_string = R"( +HloModule module + +ENTRY %concat { + %param.0 = f32[5,7] parameter(0), + sharding={devices=[2,1]0,1} + %param.1 = f32[5,9] parameter(1), + sharding={devices=[2,1]0,1} + %concat = f32[5,16] concatenate(%param.0, %param.1), + dimensions={1} + ROOT %tuple = (f32[5,16]) tuple(%concat) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "concat"), + op::Sharding("{devices=[2,1]0,1}")); +} + +TEST_F(ShardingPropagationTest, TupleBackwardPass) { + const char* const hlo_string = R"( +HloModule module + +ENTRY %tuple { + %param.0 = f32[1] parameter(0) + %param.1 = f32[3] parameter(1) + %copy.0 = f32[1] copy(%param.0) + %copy.1 = f32[3] copy(param.1) + ROOT %tuple = (f32[1], f32[3]) tuple(%copy.0, %copy.1), + sharding={{replicated}, {devices=[2]0,1}} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "copy.0"), + op::Sharding("{replicated}")); + EXPECT_THAT(FindInstruction(module.get(), "copy.1"), + op::Sharding("{devices=[2]0,1}")); +} + +TEST_F(ShardingPropagationTest, AllReduce) { + const char* const hlo_string = R"( +HloModule module + +%add (lhs: f32[], rhs: f32[]) -> f32[] { + %add_lhs = f32[] parameter(0) + %add_rhs = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %add_lhs, f32[] %add_rhs) +} + +ENTRY %entry { + %param.0 = f32[3] parameter(0) + %param.1 = f32[3] parameter(1) + + %copy_f_t = f32[3] copy(%param.1), sharding={devices=[2]0,1} + %crs_f.tiled = f32[3] all-reduce(%copy_f_t), to_apply=%add + %crs_f.none = f32[3] all-reduce(%copy_f_t), to_apply=%add, + channel_id=1 + + %crs_b.replicated = f32[3] all-reduce(%param.0), to_apply=%add + %copy_b_r = f32[3] copy(%crs_b.replicated), sharding={replicated} + + ROOT %tuple = (f32[3], f32[3], f32[3], f32[3]) tuple( + %crs_f.tiled, crs_f.none, %copy_b_r) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "crs_f.tiled"), + op::Sharding("{devices=[2]0,1}")); + EXPECT_THAT(FindInstruction(module.get(), "crs_f.none"), op::NoSharding()); + + EXPECT_THAT(FindInstruction(module.get(), "crs_b.replicated"), + op::Sharding("{replicated}")); +} + +TEST_F(ShardingPropagationTest, While) { + const char* const hlo_string = R"( +HloModule module + +%cond { + %vars.cond = (u32[], f32[10]{0}) parameter(0) + %count.cond = u32[] get-tuple-element((u32[], f32[10]{0}) %vars.cond), index=0 + %limit = u32[] constant(10) + ROOT %lt = pred[] compare(u32[] %count.cond, u32[] %limit), direction=LT +} + +%body { + %vars = (u32[], f32[10]{0}) parameter(0) + %count = u32[] 
get-tuple-element(%vars), index=0 + %acc = f32[10]{0} get-tuple-element((u32[], f32[10]{0}) %vars), index=1 + + %one = u32[] constant(1) + %count.1 = u32[] add(u32[] %count, u32[] %one), sharding={replicated} + %acc.1 = f32[10]{0} add(f32[10]{0} %acc, f32[10]{0} %acc) + ROOT %tuple = (u32[], f32[10]{0}) tuple(u32[] %count.1, f32[10]{0} %acc.1) +} + +ENTRY %entry { + %p0 = f32[10]{0} parameter(0) + %p0.copy = f32[10]{0} copy(f32[10]{0} %p0) + %p1 = f32[10]{0} parameter(1) + %zero = u32[] constant(0) + %init = (u32[], f32[10]{0}) tuple(u32[] %zero, f32[10]{0} %p0.copy) + %while = (u32[], f32[10]{0}) while((u32[], f32[10]{0}) %init), + body=%body, condition=%cond + %res = f32[10]{0} get-tuple-element((u32[], f32[10]{0}) %while), index=1 + %prev = f32[10]{0} get-tuple-element((u32[], f32[10]{0}) %init), index=1 + %res.1 = f32[10]{0} multiply(f32[10]{0} %res, %prev) + ROOT %res_tuple = (f32[10]{0}) tuple(f32[10]{0} %res.1) +})"; + + auto while_is_sharded = [this](HloModule* module, + const HloSharding& sharding) { + TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardingPropagation().Run(module)); + EXPECT_TRUE(changed); + auto while_instr = FindInstruction(module, "while"); + EXPECT_NE(nullptr, while_instr); + std::vector instructions{ + while_instr, while_instr->while_body()->root_instruction(), + while_instr->while_body()->parameter_instruction(0), + while_instr->while_condition()->parameter_instruction(0)}; + + for (auto instr : instructions) { + EXPECT_TRUE(instr->has_sharding()); + EXPECT_EQ(sharding, instr->sharding()); + } + }; + { + // Propagation of user-defined partial sharding of while-related instruction + // (body root in this test). + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + auto body_root = FindInstruction(module.get(), "tuple"); + EXPECT_NE(nullptr, body_root); + auto sharding = + ParseSharding("{{replicated}, {devices=[2]0,1}}").ConsumeValueOrDie(); + body_root->set_sharding(sharding); + while_is_sharded(module.get(), sharding); + } + { + // Propagation from acc.1 to the rest of the loop. 
+ TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + auto acc_1 = FindInstruction(module.get(), "acc.1"); + EXPECT_NE(nullptr, acc_1); + acc_1->set_sharding(ParseSharding("{devices=[2]0,1}").ConsumeValueOrDie()); + + while_is_sharded( + module.get(), + ParseSharding("{{replicated}, {devices=[2]0,1}}").ConsumeValueOrDie()); + } +} + +TEST_F(ShardingPropagationTest, Dot) { + const char* const hlo_string = R"( +HloModule module +ENTRY %conv { + %param.0 = f32[8,256,128] parameter(0) + %param.1 = f32[8,128,512] parameter(1) + %param.2 = f32[8,128] parameter(2) + + %p0_copy_0 = f32[8,256,128] copy(%param.0), + sharding={devices=[1,4,1]0,1,2,3} + %p1_copy_0 = f32[8,128,512] copy(%param.1), + sharding={devices=[1,2,2]0,1,2,3} + %p2_copy = f32[8,128] copy(%param.2) + %dot_prop_rhs = f32[8,256,512] dot(%p0_copy_0, %p1_copy_0), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1} + %dot_prop_lhs = f32[8,512,256] dot(%p1_copy_0, %p0_copy_0), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={1}, rhs_contracting_dims={2} + %dot_mat_vec = f32[8,256] dot(%p0_copy_0, %p2_copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1} + + %p0_copy_1 = f32[8,256,128] copy(%param.0) + %p1_copy_1 = f32[8,128,512] copy(%param.1) + %dot_back_prop_rhs = f32[8,256,512] dot(%p0_copy_1, %p1_copy_1), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1} + %copy_back_prop_rhs = f32[8,256,512] copy(%dot_back_prop_rhs), + sharding={devices=[1,2,2]0,1,2,3} + + ROOT %tuple = (f32[8,256,256], f32[8,256,256], f32[8,256]) + tuple(%dot_prop_lhs, %dot_prop_rhs, %dot_mat_vec, %copy_back_prop_rhs) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "dot_prop_rhs"), + op::Sharding("{devices=[1,2,2]0,1,2,3}")); + EXPECT_THAT(FindInstruction(module.get(), "dot_prop_lhs"), + op::Sharding("{devices=[1,2,2]0,1,2,3}")); + EXPECT_THAT(FindInstruction(module.get(), "dot_mat_vec"), + op::Sharding("{devices=[1,4]0,1,2,3}")); + + EXPECT_THAT(FindInstruction(module.get(), "p0_copy_1"), + op::Sharding("{replicated}")); + EXPECT_THAT(FindInstruction(module.get(), "p1_copy_1"), + op::Sharding("{devices=[1,2,2]0,1,2,3}")); + EXPECT_THAT(FindInstruction(module.get(), "dot_back_prop_rhs"), + op::Sharding("{devices=[1,2,2]0,1,2,3}")); +} + +TEST_F(ShardingPropagationTest, DotTiledBatchDim) { + const char* const hlo_string = R"( +HloModule module +ENTRY %conv { + %p0 = f32[8,256,512] parameter(0) + %p1 = f32[8,512,128] parameter(1) + + %add = f32[8,256,512] add(%p0, %p0) + %dot = f32[8,256,128] dot(%add, %p1), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1} + %res = f32[8,32768] reshape(%dot), sharding={devices=[2,2]0,1,2,3} + + ROOT %tuple = (f32[8,32768]) tuple(%res) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "add"), + op::Sharding("{devices=[2,2,1]0,1,2,3}")); +} + +TEST_F(ShardingPropagationTest, ConcatFromUserUnshardedDim) { + const char* const hlo_string = R"( +HloModule module +ENTRY %conv { + %p0 = f32[8,128] parameter(0) + %p1 = f32[8,128] 
parameter(1) + %c0 = f32[8,128] copy(%p0) + %c1 = f32[8,128] copy(%p1) + + %concat = f32[16,128] concatenate(%c0, %c1), + dimensions={0}, + sharding={devices=[1,2]0,1} + ROOT %tuple = (f32[16,128]) tuple(%concat) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "c0"), + op::Sharding("{devices=[1,2]0,1}")); + EXPECT_THAT(FindInstruction(module.get(), "c1"), + op::Sharding("{devices=[1,2]0,1}")); +} + +TEST_F(ShardingPropagationTest, ConcatFromUserShardedDim) { + const char* const hlo_string = R"( +HloModule module +ENTRY %conv { + %p0 = f32[8,128] parameter(0) + %p1 = f32[8,128] parameter(1) + %c0 = f32[8,128] copy(%p0) + %c1 = f32[8,128] copy(%p1) + + %concat = f32[16,128] concatenate(%c0, %c1), + dimensions={0}, + sharding={devices=[3,1]0,1,2} + ROOT %tuple = (f32[16,128]) tuple(%concat) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "c0"), + op::Sharding("{devices=[2,1]0,1}")); + EXPECT_THAT(FindInstruction(module.get(), "c1"), + op::Sharding("{devices=[2,1]1,2}")); +} + +TEST_F(ShardingPropagationTest, ConcatFromUserShardedDimMaximalOperand) { + const char* const hlo_string = R"( +HloModule module +ENTRY %conv { + %p0 = f32[8,128] parameter(0) + %p1 = f32[24,128] parameter(1) + %c0 = f32[8,128] copy(%p0) + %c1 = f32[24,128] copy(%p1) + + %concat = f32[32,128] concatenate(%c0, %c1), + dimensions={0}, + sharding={devices=[4,1]0,1,2,3} + ROOT %tuple = (f32[32,128]) tuple(%concat) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "c0"), op::NoSharding()); + EXPECT_THAT(FindInstruction(module.get(), "c1"), + op::Sharding("{devices=[3,1]1,2,3}")); +} + +TEST_F(ShardingPropagationTest, ReplicatedToSideEffecting) { + const char* const hlo_string = R"( +HloModule module +ENTRY entry_computation { + %const.0 = s32[] constant(0), sharding={replicated} + %const.1 = s32[] constant(2147483647), sharding={replicated} + %rng = s32[4]{0} rng(%const.0, %const.1), + distribution=rng_uniform + ROOT %root = (s32[4]{0}) tuple(%rng) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_FALSE(changed); + EXPECT_THAT(FindInstruction(module.get(), "rng"), op::NoSharding()); +} + +TEST_F(ShardingPropagationTest, PartReplicatedTupleUser) { + const char* const hlo_string = R"( +HloModule module +ENTRY entry_computation { + %param.0 = f32[5] parameter(0) + %param.1 = f32[7] parameter(1) + %param.2 = f32[9] parameter(2) + %tuple.0 = (f32[5], f32[7]) tuple(%param.0, %param.1) + ROOT %tuple.1 = ((f32[5], f32[7]), f32[9]) tuple(%tuple.0, %param.2), + sharding={{maximal device=0}, {replicated}, {maximal device=1}} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "tuple.0"), + op::Sharding("{{maximal device=0}, {replicated}}")); +} + 
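
The tests in this file all follow the same pattern: parse an HLO string, run ShardingPropagation, then check the sharding attached to one named instruction. A minimal sketch of a fixture helper that captures this pattern is shown below; the fixture and helper names are hypothetical and are not part of the patch.

class ShardingPropagationPatternTest : public HloTestBase {
 protected:
  // Parses `hlo`, runs the sharding propagation pass, and checks that the
  // instruction named `inst_name` ends up with `expected_sharding`.
  // (Illustrative sketch only; not part of the patch.)
  void ExpectSharding(const std::string& hlo, const std::string& inst_name,
                      const std::string& expected_sharding,
                      bool is_spmd = false) {
    TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
    TF_ASSERT_OK_AND_ASSIGN(bool changed,
                            ShardingPropagation(is_spmd).Run(module.get()));
    EXPECT_TRUE(changed);
    EXPECT_THAT(FindInstruction(module.get(), inst_name),
                op::Sharding(expected_sharding));
  }
};

With such a helper, most of the single-assertion tests above would reduce to one ExpectSharding call with the HLO text, the instruction name, and the expected sharding string.
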
+TEST_F(ShardingPropagationTest, Conditional) { + const char* const hlo_string = R"( +HloModule module + +%true_comp { + %tp = (f32[3,5]) parameter(0) + %tgte = f32[3,5] get-tuple-element(%tp), index=0 + %ttr = f32[5,3] transpose(%tgte), dimensions={1,0} + ROOT %tr = (f32[5,3]) tuple(%ttr) +} + +%false_comp { + %fp = (f32[5,3]) parameter(0) + %fgte = f32[5,3] get-tuple-element(%fp), index=0 + ROOT %fr = (f32[5,3]) tuple(%fgte) +} + +ENTRY entry { + %cond = pred[] parameter(0) + %true_param = (f32[3,5]) parameter(1), sharding={{devices=[1,2]0,1}} + %false_param = (f32[5,3]) parameter(2), sharding={{devices=[1,3]0,1,2}} + %conditional = (f32[5,3]) conditional( + %cond, %true_param, %false_param), + true_computation=%true_comp, + false_computation=%false_comp + ROOT %root = f32[5,3] get-tuple-element(%conditional), index=0 +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "tp"), + op::Sharding("{{devices=[1,2]0,1}}")); + EXPECT_THAT(FindInstruction(module.get(), "tgte"), + op::Sharding("{devices=[1,2]0,1}")); + EXPECT_THAT(FindInstruction(module.get(), "ttr"), + op::Sharding("{devices=[2,1]0,1}")); + EXPECT_THAT(FindInstruction(module.get(), "tr"), + op::Sharding("{{devices=[2,1]0,1}}")); + EXPECT_THAT(FindInstruction(module.get(), "fp"), + op::Sharding("{{devices=[1,3]0,1,2}}")); + EXPECT_THAT(FindInstruction(module.get(), "fgte"), + op::Sharding("{devices=[1,3]0,1,2}")); + EXPECT_THAT(FindInstruction(module.get(), "fr"), + op::Sharding("{{devices=[2,1]0,1}}")); + EXPECT_THAT(FindInstruction(module.get(), "conditional"), + op::Sharding("{{devices=[2,1]0,1}}")); +} + +TEST_F(ShardingPropagationTest, TupleFromUser) { + const char* const hlo_string = R"( +HloModule module +ENTRY %entry { + %p0 = f32[13] parameter(0) + %p1 = f32[15] parameter(1) + %p2 = f32[17] parameter(2) + %t0 = (f32[13], f32[15]) tuple(%p0, %p1) + %t1 = ((f32[13], f32[15]), f32[17]) tuple(%t0, %p2) + %gte.0 = (f32[13], f32[15]) get-tuple-element(%t1), index=0 + %gte.1 = f32[13] get-tuple-element(%gte.0), index=0 + %gte.2 = f32[15] get-tuple-element(%gte.0), index=1 + %gte.3 = f32[17] get-tuple-element(%t1), index=1 + ROOT %t2 = (f32[13], f32[15], f32[17]) tuple(%gte.1, %gte.2, %gte.3), + sharding={{replicated}, {devices=[2]0,1}, {devices=[3]1,2,3}} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "t0"), + op::Sharding("{{replicated}, {devices=[2]0,1}}")); + EXPECT_THAT( + FindInstruction(module.get(), "t1"), + op::Sharding("{{replicated}, {devices=[2]0,1}, {devices=[3]1,2,3}}")); +} + +TEST_F(ShardingPropagationTest, DynamicSliceForwardPass) { + const char* hlo_string = R"( +HloModule module +ENTRY %entry { + %p0 = f32[11,13,15] parameter(0) + %c0 = f32[11,13,15] copy(%p0), sharding={devices=[1,1,2]0,1} + %p1 = s32[] parameter(1) + %i0 = s32[] constant(0) + %ds = f32[11,1,15] dynamic-slice(%c0, %i0, %p1, %i0), + dynamic_slice_sizes={11,1,15} + ROOT %root = (f32[11,1,15]) tuple(%ds) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "ds"), + 
op::Sharding("{devices=[1,1,2]0,1}")); +} + +TEST_F(ShardingPropagationTest, DynamicSliceBackwardPass) { + const char* hlo_string = R"( +HloModule module +ENTRY %entry { + %p0 = f32[11,13,15] parameter(0) + %c0 = f32[11,13,15] copy(%p0) + %p1 = s32[] parameter(1) + %i0 = s32[] constant(0) + %ds = f32[11,1,15] dynamic-slice(%c0, %i0, %p1, %i0), + dynamic_slice_sizes={11,1,15}, + sharding={devices=[1,1,2]0,1} + ROOT %root = (f32[11,1,15]) tuple(%ds) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "ds"), + op::Sharding("{devices=[1,1,2]0,1}")); +} + +TEST_F(ShardingPropagationTest, DynamicUpdateSliceForwardPassBase) { + const char* hlo_string = R"( +HloModule module +ENTRY %entry { + %p0 = f32[11,13,15] parameter(0) + %c0 = f32[11,13,15] copy(%p0), sharding={devices=[1,1,2]0,1} + %p1 = f32[11,1,15] parameter(1) + %c1 = f32[11,1,15] copy(%p1) + %p2 = s32[] parameter(2) + %i0 = s32[] constant(0) + %dus = f32[11,13,15] dynamic-update-slice(%c0, %c1, %i0, %p2, %i0) + ROOT %root = (f32[11,13,15]) tuple(%dus) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "dus"), + op::Sharding("{devices=[1,1,2]0,1}")); + EXPECT_THAT(FindInstruction(module.get(), "c1"), + op::Sharding("{devices=[1,1,2]0,1}")); +} + +TEST_F(ShardingPropagationTest, DynamicUpdateSliceForwardPassUpdate) { + const char* hlo_string = R"( +HloModule module +ENTRY %entry { + %p0 = f32[11,13,15] parameter(0) + %c0 = f32[11,13,15] copy(%p0) + %p1 = f32[11,1,15] parameter(1) + %c1 = f32[11,1,15] copy(%p1), sharding={devices=[1,1,2]0,1} + %p2 = s32[] parameter(2) + %i0 = s32[] constant(0) + %dus = f32[11,13,15] dynamic-update-slice(%c0, %c1, %i0, %p2, %i0) + ROOT %root = (f32[11,13,15]) tuple(%dus) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "dus"), + op::Sharding("{devices=[1,1,2]0,1}")); + EXPECT_THAT(FindInstruction(module.get(), "c0"), + op::Sharding("{devices=[1,1,2]0,1}")); +} + +TEST_F(ShardingPropagationTest, DynamicUpdateSliceBackwardPass) { + const char* hlo_string = R"( +HloModule module +ENTRY %entry { + %p0 = f32[11,13,15] parameter(0) + %c0 = f32[11,13,15] copy(%p0) + %p1 = f32[11,1,15] parameter(1) + %c1 = f32[11,1,15] copy(%p1) + %p2 = s32[] parameter(2) + %i0 = s32[] constant(0) + %dus = f32[11,13,15] dynamic-update-slice(%c0, %c1, %i0, %p2, %i0), + sharding={devices=[1,1,2]0,1} + ROOT %root = (f32[11,13,15]) tuple(%dus) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "c0"), + op::Sharding("{devices=[1,1,2]0,1}")); + EXPECT_THAT(FindInstruction(module.get(), "c1"), + op::Sharding("{devices=[1,1,2]0,1}")); +} + +TEST_F(ShardingPropagationTest, EinsumLHSBatchPartitioned) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64] parameter(0) + %lhs.copy = f32[32,24,64] copy(%lhs), sharding={devices=[2,1,1]0,1} + %rhs = 
f32[32,39296,64] parameter(1) + %rhs.copy = f32[32,39296,64] copy(%rhs) + %conv = f32[32,24,39296] convolution(%lhs.copy, %rhs.copy), + dim_labels=0bf_0oi->0bf, window={size=32 stride=31 lhs_dilate=32} + ROOT %copy = f32[32,24,39296] copy(%conv) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "rhs.copy"), + op::Sharding("{devices=[2,1,1]0,1}")); + EXPECT_THAT(FindInstruction(module.get(), "conv"), + op::Sharding("{devices=[2,1,1]0,1}")); +} + +TEST_F(ShardingPropagationTest, EinsumOutputBatchPartitioned) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64] parameter(0) + %lhs.copy = f32[32,24,64] copy(%lhs) + %rhs = f32[32,39296,64] parameter(1) + %rhs.copy = f32[32,39296,64] copy(%rhs) + %conv = f32[32,24,39296] convolution(%lhs.copy, %rhs.copy), + dim_labels=0bf_0oi->0bf, window={size=32 stride=31 lhs_dilate=32}, + sharding={devices=[2,1,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "lhs.copy"), + op::Sharding("{devices=[2,1,1]0,1}")); + EXPECT_THAT(FindInstruction(module.get(), "rhs.copy"), + op::Sharding("{devices=[2,1,1]0,1}")); +} + +TEST_F(ShardingPropagationTest, EinsumLHSNonContractingPartitioned) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,2]0,1,2,3} + %rhs = f32[32,39296,64,1] parameter(1) + %rhs.copy = f32[32,39296,64,1] copy(%rhs) + %conv = f32[32,24,39296,128] convolution(%lhs.copy, %rhs.copy), + dim_labels=0bf1_0oi1->0bf1, window={size=32x1 stride=31x1 lhs_dilate=32x1} + ROOT %copy = f32[32,24,39296,128] copy(%conv) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "conv"), + op::Sharding("{devices=[1,2,1,2]0,1,2,3}")); +} + +TEST_F(ShardingPropagationTest, EinsumOutputLHSNonContractingPartitioned) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs) + %rhs = f32[32,39296,64,1] parameter(1) + %rhs.copy = f32[32,39296,64,1] copy(%rhs) + ROOT %conv = f32[32,24,39296,128] convolution(%lhs.copy, %rhs.copy), + dim_labels=0bf1_0oi1->0bf1, window={size=32x1 stride=31x1 lhs_dilate=32x1}, + sharding={devices=[1,2,1,2]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "lhs.copy"), + op::Sharding("{devices=[1,2,1,2]0,1,2,3}")); +} + +TEST_F(ShardingPropagationTest, EinsumRHSNonContractingPartitioned) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64,1] parameter(0) + %lhs.copy = f32[32,24,64,1] copy(%lhs) + %rhs = f32[32,39296,64,128] parameter(1) + %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={devices=[1,2,1,2]0,1,2,3} + %conv = f32[32,24,39296,128] convolution(%lhs.copy, %rhs.copy), + 
dim_labels=0bf1_0oi1->0bf1, + window={size=32x128 stride=31x1 pad=0_0x127_127 lhs_dilate=32x1 rhs_reversal=0x1} + ROOT %copy = f32[32,24,39296,128] copy(%conv) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "conv"), + op::Sharding("{devices=[1,1,2,2]0,1,2,3}")); +} + +TEST_F(ShardingPropagationTest, EinsumOutputRHSNonContractingPartitioned) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64,1] parameter(0) + %lhs.copy = f32[32,24,64,1] copy(%lhs) + %rhs = f32[32,39296,64,128] parameter(1) + %rhs.copy = f32[32,39296,64,128] copy(%rhs) + ROOT %conv = f32[32,24,39296,128] convolution(%lhs.copy, %rhs.copy), + dim_labels=0bf1_0oi1->0bf1, + window={size=32x128 stride=31x1 pad=0_0x127_127 lhs_dilate=32x1 rhs_reversal=0x1}, + sharding={devices=[1,1,2,2]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "rhs.copy"), + op::Sharding("{devices=[1,2,1,2]0,1,2,3}")); +} + +TEST_F(ShardingPropagationTest, EinsumChooseLargerOperand) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64,1] parameter(0) + %lhs.copy = f32[32,24,64,1] copy(%lhs), sharding={devices=[1,4,1,1]0,1,2,3} + %rhs = f32[32,39296,64,128] parameter(1) + %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={devices=[1,2,1,2]0,1,2,3} + %conv = f32[32,24,39296,128] convolution(%lhs.copy, %rhs.copy), + dim_labels=0bf1_0oi1->0bf1, + window={size=32x128 stride=31x1 pad=0_0x127_127 lhs_dilate=32x1 rhs_reversal=0x1} + ROOT %copy = f32[32,24,39296,128] copy(%conv) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "conv"), + op::Sharding("{devices=[1,1,2,2]0,1,2,3}")); +} + +TEST_F(ShardingPropagationTest, EinsumChooseBatchFirst) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64,1] parameter(0) + %lhs.copy = f32[32,24,64,1] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[32,39296,64,128] parameter(1) + %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={devices=[2,1,1,1]0,1} + %conv = f32[32,24,39296,128] convolution(%lhs.copy, %rhs.copy), + dim_labels=0bf1_0oi1->0bf1, + window={size=32x128 stride=31x1 pad=0_0x127_127 lhs_dilate=32x1 rhs_reversal=0x1} + ROOT %copy = f32[32,24,39296,128] copy(%conv) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "conv"), + op::Sharding("{devices=[2,1,1,1]0,1}")); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/BUILD b/tensorflow/compiler/xla/service/spmd/BUILD new file mode 100644 index 00000000000..4433078472d --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/BUILD @@ -0,0 +1,71 @@ +# Description: SPMD partitioning pass. 
+ +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +package( + default_visibility = [":friends"], + licenses = ["notice"], # Apache 2.0 +) + +package_group( + name = "friends", + includes = [ + "//tensorflow/compiler/xla:friends", + ], +) + +cc_library( + name = "spmd_partitioner", + srcs = [ + "spmd_partitioner.cc", + "spmd_partitioner_util.cc", + ], + hdrs = [ + "spmd_partitioner.h", + "spmd_partitioner_util.h", + ], + deps = [ + "//tensorflow/compiler/xla:comparison_util", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:protobuf_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/client/lib:comparators", + "//tensorflow/compiler/xla/service:dot_as_convolution_util", + "//tensorflow/compiler/xla/service:flatten_call_graph", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", + "//tensorflow/compiler/xla/service:hlo_cse", + "//tensorflow/compiler/xla/service:hlo_dce", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service:hlo_pass_pipeline", + "//tensorflow/compiler/xla/service:hlo_query", + "//tensorflow/compiler/xla/service:hlo_sharding_util", + "//tensorflow/compiler/xla/service:pattern_matcher", + "//tensorflow/compiler/xla/service:shape_inference", + "//tensorflow/compiler/xla/service:tuple_simplifier", + "//tensorflow/core/platform:numbers", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) + +tf_cc_test( + name = "spmd_partitioner_test", + srcs = ["spmd_partitioner_test.cc"], + deps = [ + ":spmd_partitioner", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service:hlo_pass_pipeline", + "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc new file mode 100644 index 00000000000..635446a18a1 --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc @@ -0,0 +1,4943 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h" + +#include + +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/client/lib/comparators.h" +#include "tensorflow/compiler/xla/comparison_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/protobuf_util.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h" +#include "tensorflow/compiler/xla/service/flatten_call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_cse.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" +#include "tensorflow/compiler/xla/service/hlo_query.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_util.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h" +#include "tensorflow/compiler/xla/service/tuple_simplifier.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/numbers.h" + +namespace xla { +namespace spmd { + +string SpmdLogger::MakeReport() { + string report; + absl::StrAppend(&report, + "\n\n***** SPMD memory during transformation *****\n"); + + std::sort(entries_.begin(), entries_.end(), + [](auto const& entry0, auto const& entry1) { + return entry0.first > entry1.first; + }); + for (int64 i = 0; + i < std::min(report_instruction_count_, entries_.size()); ++i) { + absl::StrAppend( + &report, "\n ", + tensorflow::strings::HumanReadableNumBytes(entries_[i].first), " : ", + entries_[i].second, "\n"); + } + + return report; +} + +void SpmdLogger::RegisterLogEntry(HloInstruction* hlo, + const std::vector& group) { + string report = hlo->ToString(); + int64 max_value = -1; + for (HloInstruction* inst : group) { + if (!inst->shape().IsArray()) { + continue; + } + max_value = std::max(max_value, ShapeSizeInBytes(inst->shape())); + absl::StrAppend(&report, " * ", inst->ToString(), "\n"); + } + entries_.push_back(std::make_pair(max_value, report)); +} + +/* static */ string SpmdLogger::ReportBeforePartition( + const HloModule& module, int64 report_instruction_count) { + string report; + absl::StrAppend(&report, + "\n\n***** SPMD memory usage before partition *****\n"); + absl::StrAppend(&report, "\n ** Replicated instructions\n"); + absl::StrAppend(&report, ReportMemoryUsage( + module, + [](const HloInstruction* hlo) { + return !hlo->has_sharding() || + hlo->sharding().IsReplicated(); + }, + report_instruction_count)); + absl::StrAppend(&report, "\n ** All instructions\n"); + absl::StrAppend(&report, + ReportMemoryUsage( + module, [](const HloInstruction* hlo) { return true; }, + report_instruction_count)); + return report; +} 
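Both report helpers above follow the same pattern: collect candidate instructions, sort them by shape size in bytes, and print the largest report_instruction_count entries with human-readable byte counts. The following standalone sketch illustrates just that pattern with plain C++ containers; the Entry struct and MakeTopNReport name are illustrative stand-ins and not part of this change.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

struct Entry {
  int64_t bytes;
  std::string description;
};

// Prints the top_n largest entries, mirroring the sort-then-truncate logic
// used by SpmdLogger::MakeReport and ReportMemoryUsage.
std::string MakeTopNReport(std::vector<Entry> entries, int64_t top_n) {
  std::sort(entries.begin(), entries.end(),
            [](const Entry& a, const Entry& b) { return a.bytes > b.bytes; });
  std::string report;
  const int64_t limit =
      std::min<int64_t>(top_n, static_cast<int64_t>(entries.size()));
  for (int64_t i = 0; i < limit; ++i) {
    report += "  " + std::to_string(entries[i].bytes) + " bytes : " +
              entries[i].description + "\n";
  }
  return report;
}

int main() {
  std::cout << MakeTopNReport({{1024, "%dot.1"}, {4096, "%convolution.2"}},
                              /*top_n=*/5);
}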
+
+/* static */ string SpmdLogger::ReportAfterPartition(
+    const HloModule& module, int64 report_instruction_count) {
+  string report;
+  absl::StrAppend(&report,
+                  "\n\n***** SPMD memory usage after partition *****\n");
+  absl::StrAppend(&report,
+                  ReportMemoryUsage(
+                      module, [](const HloInstruction* hlo) { return true; },
+                      report_instruction_count));
+  return report;
+}
+
+template <typename F>
+/* static */ string SpmdLogger::ReportMemoryUsage(
+    const HloModule& module, const F& filter, int64 report_instruction_count) {
+  string report;
+  std::vector<HloInstruction*> instructions;
+  instructions.reserve(module.instruction_count());
+
+  for (auto computation : module.computations()) {
+    if (computation->IsFusionComputation()) {
+      continue;
+    }
+    for (auto hlo : computation->instructions()) {
+      if (hlo->shape().IsTuple() ||
+          ShapeUtil::IsEffectiveScalar(hlo->shape())) {
+        continue;
+      }
+      if (filter(hlo)) {
+        instructions.push_back(hlo);
+      }
+    }
+  }
+
+  const auto add_report = [&](std::vector<HloInstruction*>* insts) {
+    std::sort(insts->begin(), insts->end(),
+              [](const HloInstruction* inst0, const HloInstruction* inst1) {
+                return ShapeSizeInBytes(inst0->shape()) >
+                       ShapeSizeInBytes(inst1->shape());
+              });
+    for (int64 i = 0;
+         i < std::min<int64>(report_instruction_count, insts->size()); ++i) {
+      absl::StrAppend(&report, "  ",
+                      tensorflow::strings::HumanReadableNumBytes(
+                          ShapeSizeInBytes((*insts)[i]->shape())),
+                      " : ", (*insts)[i]->ToString(), "\n");
+    }
+  };
+
+  add_report(&instructions);
+  return report;
+}
+
+namespace {
+
+// Returns the replica group configuration where each replica belongs to its own
+// group.
+std::vector<ReplicaGroup> CreateReplicaGroups(int64 num_replicas) {
+  std::vector<ReplicaGroup> groups(num_replicas);
+  for (int64 i = 0; i < num_replicas; ++i) {
+    groups[i].add_replica_ids(i);
+  }
+  return groups;
+}
+
+bool CanReshardWithAllToAll(const HloSharding& source,
+                            const HloSharding& target) {
+  return UniqueTiledDim(source) && UniqueTiledDim(target) &&
+         UniqueTiledDim(source) != UniqueTiledDim(target);
+}
+
+bool CanReshardWithCollectivePermute(const HloSharding& source,
+                                     const HloSharding& target) {
+  return UniqueTiledDim(source) && UniqueTiledDim(target) &&
+         UniqueTiledDim(source) == UniqueTiledDim(target) && source != target;
+}
+
+// Clears all sharding attributes from instructions in the module. This must be
+// called only after all SPMD transformation is complete.
+Status ClearShardingAttributes(HloModule* module) {
+  for (HloComputation* computation : module->computations()) {
+    for (HloInstruction* hlo : computation->instructions()) {
+      // Keep sharding annotation on Infeed and entry parameters since they're
+      // used by HloReplicationAnalysis later (for ArCrsCombiner).
+ if (hlo->opcode() == HloOpcode::kInfeed) { + continue; + } + if (hlo->opcode() == HloOpcode::kParameter && + computation == module->entry_computation()) { + continue; + } + hlo->clear_sharding(); + } + } + return Status::OK(); +} + +} // namespace + +HloInstruction* SpmdBuilder::AddInstruction( + std::unique_ptr instruction) { + HloInstruction* hlo = + HloComputation::Builder::AddInstruction(std::move(instruction)); + if (visiting_hlo_) { + instructions_[visiting_hlo_].push_back(hlo); + } + return hlo; +} + +PartitionedHlo PartitionedHlo::Reshard(const HloSharding& target) { + auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache; + for (auto& entry : cache) { + if (entry.first == target) { + return entry.second; + } + } + cache.emplace_back(target, ReshardNoCache(target)); + state_.reshard_cache->per_hlo_cache[cache.back().second.hlo()] + .reshard_cache.emplace_back(sharding(), *this); + return cache.back().second; +} + +PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) { + VLOG(2) << "Resharding " << hlo_->ToString() << " from " + << hlo_->sharding().ToString() << " to " << target.ToString(); + const Shape& shape = hlo_->shape(); + CHECK(shape.IsTuple() || !target.IsTuple()); + + // Tuple shape instructions may have non-tuple sharding, which means that the + // same sharding applies to all the leaves. + if (shape.IsTuple() && !target.IsTuple()) { + return Reshard(target.GetTupleSharding(shape).ValueOrDie()); + } + + // For a tuple shape, recursively apply Reshard to all the leaves and return + // a tuple instruction. + if (shape.IsTuple()) { + std::vector elements; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + auto subshape = ShapeUtil::GetTupleElementShape(shape, i); + auto element = state_.b->AddInstruction( + HloInstruction::CreateGetTupleElement(subshape, hlo(), i)); + element->set_sharding(sharding().GetSubSharding(shape, {i})); + elements.push_back( + PartitionedHlo( + element, ShapeUtil::GetTupleElementShape(base_shape_, i), state_) + .Reshard(target.GetSubSharding(shape, {i})) + .hlo()); + } + auto tuple = + state_.b->AddInstruction(HloInstruction::CreateTuple(elements)); + tuple->set_sharding(target); + return PartitionedHlo(tuple, base_shape_, state_); + } + + if (sharding() == target) { + return *this; + } + + if (shape.element_type() == TOKEN) { + return *this; + } + + if (CanReshardWithCollectivePermute(sharding(), target)) { + return ReshardWithCollectivePermute(target); + } + + if (CanReshardWithAllToAll(sharding(), target)) { + return ReshardWithAllToAll(target); + } + + // If not replicated yet, first replicate and then reshard to use one of the + // two implementations below. + if (!sharding().IsReplicated()) { + return Replicate().Reshard(target); + } + + // 'Replicated' to 'SingleDevice'. + if (target.IsTileMaximal()) { + auto copy = state_.b->AddInstruction( + HloInstruction::CreateUnary(hlo_->shape(), HloOpcode::kCopy, hlo_)); + copy->set_sharding(target); + return PartitionedHlo(copy, base_shape_, state_); + } + + // 'Replicated' to 'Tiled'. 
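+  // For example, resharding a replicated f32[9] to {devices=[4]0,1,2,3}
+  // pads the base shape to f32[12] and lets partition i dynamic-slice an
+  // f32[3] shard starting at offset 3*i (the offsets MakePartitionOffsets
+  // is expected to produce for this tiling).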
+ auto padded_hlo = + PadBaseShapeBeforeUnevenTiledSharding(hlo_, target, state_.b); + auto shard_shape = MakePartitionedShape(shape, target); + auto slice = state_.b->AddInstruction(HloInstruction::CreateDynamicSlice( + shard_shape, padded_hlo, + MakePartitionOffsets(shape, target, state_.partition_id, state_.b), + shard_shape.dimensions())); + slice->set_sharding(target); + return PartitionedHlo(slice, base_shape_, state_); +} + +PartitionedHlo PartitionedHlo::PadWithValue( + HloInstruction* pad_value, absl::Span left_padded_dims) const { + const HloSharding& sharding = hlo_->sharding(); + const Shape& shape = hlo_->shape(); + CHECK(!shape.IsTuple() && shape.element_type() != TOKEN); + if (sharding.IsReplicated() || EvenlyPartitions(base_shape_, sharding)) { + return *this; + } + CHECK(!sharding.IsTileMaximal()); + auto index_shape = ShapeUtil::ChangeElementType(shape, S32); + auto mask_shape = ShapeUtil::ChangeElementType(index_shape, PRED); + auto get_mask_for_dim = [&](int64 dim, HloInstruction* start_index) { + // Comparison: iota + start_index < valid_size + auto iota = + state_.b->AddInstruction(HloInstruction::CreateIota(index_shape, dim)); + auto broadcast_start_index = state_.b->AddInstruction( + HloInstruction::CreateBroadcast(index_shape, start_index, {})); + auto index_in_full_shape = + state_.b->AddInstruction(HloInstruction::CreateBinary( + index_shape, HloOpcode::kAdd, iota, broadcast_start_index)); + ComparisonDirection direction = ComparisonDirection::kLt; + int64 index_limit = base_shape_.dimensions(dim); + if (absl::c_linear_search(left_padded_dims, dim)) { + direction = ComparisonDirection::kGe; + index_limit = + index_shape.dimensions(dim) * sharding.tile_assignment().dim(dim) - + index_limit; + } + auto limit = state_.b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(index_limit))); + auto broadcast_limit = state_.b->AddInstruction( + HloInstruction::CreateBroadcast(index_shape, limit, {})); + return state_.b->AddInstruction(HloInstruction::CreateCompare( + mask_shape, index_in_full_shape, broadcast_limit, direction)); + }; + + HloInstruction* mask = nullptr; + auto offsets = MakePartitionOffsets(base_shape_, sharding, + state_.partition_id, state_.b); + for (int64 i = 0; i < shape.rank(); ++i) { + if (base_shape_.dimensions(i) % sharding.tile_assignment().dim(i) == 0) { + continue; + } + if (mask == nullptr) { + mask = get_mask_for_dim(i, offsets[i]); + } else { + mask = state_.b->AddInstruction( + HloInstruction::CreateBinary(mask->shape(), HloOpcode::kAnd, mask, + get_mask_for_dim(i, offsets[i]))); + } + } + + if (mask == nullptr) { + return *this; + } + + auto broadcast_pad_value = state_.b->AddInstruction( + HloInstruction::CreateBroadcast(shape, pad_value, {})); + auto result = state_.b->AddInstruction(HloInstruction::CreateTernary( + shape, HloOpcode::kSelect, mask, hlo_, broadcast_pad_value)); + result->set_sharding(sharding); + return PartitionedHlo(result, base_shape_, state_); +} + +absl::optional +PartitionedHlo::ReshardAsWindowedInput(const Window& window, + const HloSharding& target, + HloInstruction* pad_value, + bool mask_invalid_region) { + auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].window_reshard_cache; + for (auto& entry : cache) { + if (std::get<0>(entry) == target && + protobuf_util::ProtobufEquals(std::get<1>(entry), window)) { + return std::get<2>(entry); + } + } + auto update_cache = [&](WindowedInputShardReturnValue result) { + cache.emplace_back(target, window, std::move(result)); + return 
std::get<2>(cache.back()); + }; + VLOG(2) << "ReshardAsWindowedInput()\n" + << "\twindow:" << window_util::ToString(window) + << "\ttarget sharding:" << target.ToString(); + + CHECK(!target.IsTileMaximal()); + auto partition_ordinals = + MakeTiledPartitionOrdinals(target, state_.partition_id, state_.b); + auto shard_shape = base_shape_; + + std::vector start_on_padded_calculations( + base_shape_.rank()); + std::vector limit_on_padded_calculations( + base_shape_.rank()); + std::vector dynamic_slice_offset_on_output( + base_shape_.rank(), nullptr); + + Window shard_window = window; + auto padded_shape = base_shape_; + std::vector offsets_on_padded_shape(base_shape_.rank()); + std::vector per_shard_window_counts(base_shape_.rank()); + std::vector explicit_left_padding(base_shape_.rank()); + for (int64 i = 0; i < base_shape_.rank(); ++i) { + // Do not pad non-partitioned dimensions. + int64 shard_count = target.tile_assignment().dim(i); + if (shard_count == 1) { + offsets_on_padded_shape[i] = state_.b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(S32))); + continue; + } + const auto& wd = window.dimensions(i); + if (wd.window_dilation() != 1) { + // TODO(yuanzx): Support window dilation. + VLOG(2) << "Failed to reshard window operand due to window dilation"; + return absl::nullopt; + } + int64 full_size = + base_shape_.dimensions(i) + + (wd.base_dilation() - 1) * (base_shape_.dimensions(i) - 1) + + wd.padding_high() + wd.padding_low(); + if (full_size < wd.size()) { + VLOG(2) << "Failed to reshard window operand because the window size is " + "larger than padded base size"; + return absl::nullopt; + } + int64 window_count = (full_size - wd.size()) / wd.stride() + 1; + per_shard_window_counts[i] = CeilOfRatio(window_count, shard_count); + if (wd.stride() != 1 && + (wd.stride() * per_shard_window_counts[i]) % wd.base_dilation() != 0) { + // TODO(yuanzx): Support this case. + VLOG(2) << "Failed to reshard window operand due to non-trivial dilation"; + return absl::nullopt; + } + + // We use explicit padding for full dilations, then use padding_low and + // padding_high on the sharded op for the remaining. padding_low and + // padding_high are now given initial values, which will be later updated if + // dilation is not 1. + auto swd = shard_window.mutable_dimensions(i); + explicit_left_padding[i] = wd.padding_low() / wd.base_dilation(); + swd->set_padding_low(wd.padding_low() % wd.base_dilation()); + swd->set_padding_high(0); + + // Calculation for the first element needed on the 'padded-but-not-dilated' + // shape. The start on the dilated shape could be a hole, so we add + // wd.base_dilation() - 1 to the constant term to skip the leading holes. 
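+    // In the common case of no base dilation and no low padding this reduces
+    // to x -> wd.stride() * per_shard_window_counts[i] * x, i.e. partition x
+    // starts at the offset of its first window on the padded shape.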
+ start_on_padded_calculations[i] = MultiplyAddDivideOffsetCalculation( + wd.stride() * per_shard_window_counts[i], + wd.base_dilation() - 1 - swd->padding_low(), wd.base_dilation()); + int64 dilated_shard_size = + wd.stride() * (per_shard_window_counts[i] - 1) + wd.size(); + limit_on_padded_calculations[i] = MultiplyAddDivideOffsetCalculation( + wd.stride() * per_shard_window_counts[i], + dilated_shard_size + wd.base_dilation() - 1 - swd->padding_low(), + wd.base_dilation()); + + offsets_on_padded_shape[i] = start_on_padded_calculations[i].Calculate( + partition_ordinals[i], state_.b); + + auto shard_size_function = + limit_on_padded_calculations[i] - start_on_padded_calculations[i]; + int64 max_shard_size = shard_size_function.MaxInRange(0, shard_count); + shard_shape.set_dimensions(i, max_shard_size); + padded_shape.set_dimensions( + i, limit_on_padded_calculations[i].Calculate(shard_count - 1)); + + // For base dilation, calculate the needed padding_low and padding_high, as + // well as the offset for the output if a dynamic slice is needed after the + // sharded op. + if (wd.base_dilation() != 1) { + // Returns the offset of a shard's first valid element in the dilated + // shard. + auto get_first_valid_element_offset_on_dilated_shard = + [&](int64 shard_ordinal) { + return start_on_padded_calculations[i].Calculate(shard_ordinal) * + wd.base_dilation() + + swd->padding_low() - + wd.stride() * per_shard_window_counts[i] * shard_ordinal; + }; + CHECK_EQ(get_first_valid_element_offset_on_dilated_shard(0), + swd->padding_low()); + + // Determine swd->padding_high. + for (int64 shard_ordinal = 0; shard_ordinal < shard_count; + ++shard_ordinal) { + int64 wanted_limit_on_dilated_shard = + wd.stride() * (per_shard_window_counts[i] - 1) + wd.size(); + int64 actual_limit_on_dilated_shard_without_pad_high = + get_first_valid_element_offset_on_dilated_shard(shard_ordinal) + + (max_shard_size - 1) * wd.base_dilation() + 1; + swd->set_padding_high(std::max( + swd->padding_high(), + wanted_limit_on_dilated_shard - + actual_limit_on_dilated_shard_without_pad_high)); + } + + // Determine swd->padding_low and output dynamic slice index. 
+ if (wd.stride() == 1) { + int64 max_pad_low = get_first_valid_element_offset_on_dilated_shard(0); + bool all_same = true; + for (int64 shard_ordinal = 1; shard_ordinal < shard_count; + ++shard_ordinal) { + int64 start = + get_first_valid_element_offset_on_dilated_shard(shard_ordinal); + if (start != swd->padding_low()) { + all_same = false; + } + max_pad_low = std::max(max_pad_low, start); + } + if (!all_same) { + auto start_on_padded_input = + start_on_padded_calculations[i].Calculate(partition_ordinals[i], + state_.b); + // We will calculate + // max_pad_low - (first_window - required_first_window) + // which equals + // required_first_window - (first_window - max_pad_low) + auto first_window_minus_max_pad_low = + MultiplyAddDivideOffsetCalculation( + wd.base_dilation(), swd->padding_low() - max_pad_low, 1) + .Calculate(start_on_padded_input, state_.b); + auto required_first_window = + MultiplyAddDivideOffsetCalculation(per_shard_window_counts[i], 0, + 1) + .Calculate(partition_ordinals[i], state_.b); + dynamic_slice_offset_on_output[i] = + state_.b->AddInstruction(HloInstruction::CreateBinary( + required_first_window->shape(), HloOpcode::kSubtract, + required_first_window, first_window_minus_max_pad_low)); + } + swd->set_padding_low(max_pad_low); + } else { + if ((wd.stride() * per_shard_window_counts[i]) % wd.base_dilation() != + 0) { + // General base dilation not yet implemented. + return absl::nullopt; + } + // padding_low on all shards should equal the initially assigned + // swd->padding_low(), i.e., the padding_low() on the original window. + } + } + } + + // Returns the output dynamic slice offset when needed, and absl::nullopt + // otherwise. + auto get_dynamic_slice_offset_on_output_if_needed = + [&]() -> absl::optional> { + if (absl::c_all_of( + dynamic_slice_offset_on_output, + [](HloInstruction* offset) { return offset == nullptr; })) { + return absl::nullopt; + } + auto zero = state_.b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(S32))); + for (int64 i = 0; i < dynamic_slice_offset_on_output.size(); ++i) { + if (dynamic_slice_offset_on_output[i] == nullptr) { + dynamic_slice_offset_on_output[i] = zero; + } + } + return dynamic_slice_offset_on_output; + }; + + // If the currrent HLO is replicated, pad then slice. + if (sharding().IsReplicated()) { + PaddingConfig padding_config; + for (int64 i = 0; i < base_shape_.rank(); ++i) { + auto padding_config_dim = padding_config.add_dimensions(); + padding_config_dim->set_interior_padding(0); + // Do not pad non-partitioned dimensions. + if (target.tile_assignment().dim(i) == 1) { + padding_config_dim->set_edge_padding_low(0); + padding_config_dim->set_edge_padding_high(0); + continue; + } + padding_config_dim->set_edge_padding_low(explicit_left_padding[i]); + padding_config_dim->set_edge_padding_high(padded_shape.dimensions(i) - + explicit_left_padding[i] - + base_shape_.dimensions(i)); + } + auto padded_hlo = ShapeUtil::Compatible(padded_shape, base_shape_) + ? hlo_ + : state_.b->AddInstruction(HloInstruction::CreatePad( + padded_shape, hlo_, pad_value, padding_config)); + auto sharded_input = + state_.b->AddInstruction(HloInstruction::CreateDynamicSlice( + shard_shape, padded_hlo, offsets_on_padded_shape, + shard_shape.dimensions())); + return update_cache(WindowedInputShardReturnValue{ + sharded_input, shard_window, + get_dynamic_slice_offset_on_output_if_needed()}); + } + + if (target != sharding()) { + return Reshard(target).ReshardAsWindowedInput(window, target, pad_value); + } + + // Halo exchange. 
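+  // For example, a size-10 dimension split across 2 partitions (shard size 5)
+  // with a size-3, stride-1 window yields 4 windows per shard: partition 0
+  // needs one element of right halo (its windows read elements [0, 6)) and
+  // partition 1 needs one element of left halo (its windows read [4, 10)),
+  // which is what the left/right halo size functions below evaluate to.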
+ HloInstruction* visiting_hlo = hlo_; + auto original_shard_shape = MakePartitionedShape(base_shape_, target); + + std::vector left_halo_size_functions(base_shape_.rank()); + std::vector right_halo_size_functions(base_shape_.rank()); + // TODO(yuanzx): We are concatenating on each sharded dimension one at time, + // and in the second dimension (and beyond) we create halos by slicing the + // concat in the previous dimension, which is not optimal. We should generate + // halos only concating slices, instead of slicing concats. + for (int dim = 0; dim < base_shape_.rank(); ++dim) { + int64 shard_count = target.tile_assignment().dim(dim); + if (shard_count == 1) { + continue; + } + int64 input_shard_size = + CeilOfRatio(base_shape_.dimensions(dim), shard_count); + + // Left halo. The size of the halo is derived by subtracting the first read + // element offset of the i'th partition from the limit of the (i-1)'th + // partition. + MultiplyAddDivideOffsetCalculation shard_limit_of_previous_on_padded( + input_shard_size, explicit_left_padding[dim], 1); + left_halo_size_functions[dim] = + shard_limit_of_previous_on_padded - start_on_padded_calculations[dim]; + + // Right halo. + MultiplyAddDivideOffsetCalculation shard_start_of_next_on_padded( + input_shard_size, input_shard_size + explicit_left_padding[dim], 1); + right_halo_size_functions[dim] = + limit_on_padded_calculations[dim] - shard_start_of_next_on_padded; + + auto resharded = ExchangeHaloAndGetValidData( + visiting_hlo, base_shape_, left_halo_size_functions[dim], + right_halo_size_functions[dim], explicit_left_padding[dim], + padded_shape.dimensions(dim), shard_shape.dimensions(dim), dim, target, + offsets_on_padded_shape[dim], pad_value, partition_ordinals[dim], + state_.collective_ops_creator, state_.next_channel_id, state_.b, + mask_invalid_region); + if (!resharded) { + VLOG(1) << "ReshardAsWindowedInput failed without replicate first: halo " + "is beyond the neighbor."; + return Replicate().ReshardAsWindowedInput(window, target, pad_value); + } + visiting_hlo = *resharded; + } + return update_cache(WindowedInputShardReturnValue{ + visiting_hlo, shard_window, + get_dynamic_slice_offset_on_output_if_needed()}); +} + +PartitionedHlo PartitionedHlo::Replicate() { + const HloSharding& sharding = hlo_->sharding(); + const Shape& shape = hlo_->shape(); + CHECK(!shape.IsTuple() && shape.element_type() != TOKEN); + + if (sharding.IsReplicated()) { + return *this; + } + auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache; + for (auto& entry : cache) { + if (entry.first.IsReplicated()) { + return entry.second; + } + } + auto update_cache = [&](PartitionedHlo resharded) { + state_.reshard_cache->per_hlo_cache[resharded.hlo()] + .reshard_cache.emplace_back(sharding, *this); + cache.emplace_back(HloSharding::Replicate(), std::move(resharded)); + return cache.back().second; + }; + // 'Single Device' to 'Repliated'. + if (sharding.IsTileMaximal()) { + return update_cache(Broadcast()); + } + + // 'Tiled' to 'Replicated'. 
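+  // Either all-gather the shards directly (when the backend provides an
+  // all-gather creator), or build the full padded shape by having every
+  // partition dynamic-update-slice its shard into a zero tensor at its own
+  // offset and then summing across partitions with an all-reduce; uneven
+  // tilings are sliced back down to the base shape at the end.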
+ HloInstruction* result = nullptr; + if (state_.collective_ops_creator.create_cross_partition_all_gather) { + result = state_.partitioner->AllGatherShards(state_.b, hlo_, sharding, + NewChannel()); + } + Shape padded_base_shape = shape; + for (int64 i = 0; i < padded_base_shape.rank(); ++i) { + padded_base_shape.set_dimensions( + i, shape.dimensions(i) * sharding.tile_assignment().dim(i)); + } + if (result == nullptr) { + auto zero = state_.b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(shape.element_type()))); + auto zero_bcast = state_.b->AddInstruction( + HloInstruction::CreateBroadcast(padded_base_shape, zero, {})); + auto dus = + state_.b->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + padded_base_shape, zero_bcast, hlo_, + MakePartitionOffsets(padded_base_shape, sharding, + state_.partition_id, state_.b))); + HloComputation* reduction = + MakeBinaryAdd(shape.element_type(), state_.module); + + auto all_reduce = + state_.collective_ops_creator.create_cross_partition_all_reduce( + state_.b, dus, reduction, NewChannel()); + result = all_reduce; + } + if (!ShapeUtil::Compatible(base_shape_, padded_base_shape)) { + std::vector start_indices(shape.rank(), 0); + std::vector strides(shape.rank(), 1); + result = state_.b->AddInstruction(HloInstruction::CreateSlice( + base_shape_, result, start_indices, base_shape_.dimensions(), strides)); + } + result->set_sharding(HloSharding::Replicate()); + return update_cache(PartitionedHlo(result, base_shape_, state_)); +} + +PartitionedHlo PartitionedHlo::Broadcast() const { + const Shape& shape = hlo_->shape(); + const HloSharding& sharding = hlo_->sharding(); + CHECK(sharding.HasUniqueDevice()); + CHECK(!shape.IsTuple() && shape.element_type() != TOKEN); + + auto src_core_id = state_.b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(sharding.GetUniqueDevice()))); + Shape bcast_shape = ShapeUtil::ChangeElementType(shape, PRED); + auto is_src_core = state_.b->AddInstruction(HloInstruction::CreateBroadcast( + bcast_shape, + state_.b->AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), state_.partition_id, src_core_id, + ComparisonDirection::kEq)), + {})); + + auto zero = state_.b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type()))); + auto zero_bcast = state_.b->AddInstruction( + HloInstruction::CreateBroadcast(shape, zero, {})); + auto operand = state_.b->AddInstruction(HloInstruction::CreateTernary( + shape, HloOpcode::kSelect, is_src_core, hlo(), zero_bcast)); + HloComputation* reduction = + MakeBinaryAdd(shape.element_type(), state_.module); + + auto result = state_.collective_ops_creator.create_cross_partition_all_reduce( + state_.b, operand, reduction, NewChannel()); + result->set_sharding(HloSharding::Replicate()); + return PartitionedHlo(result, base_shape_, state_); +} + +PartitionedHlo PartitionedHlo::ReshardWithAllToAll( + const HloSharding& target) const { + int64 partition_count = sharding().tile_assignment().num_elements(); + absl::optional input_partition_dim = UniqueTiledDim(sharding()); + absl::optional output_partition_dim = UniqueTiledDim(target); + CHECK(input_partition_dim.has_value()); + CHECK(output_partition_dim.has_value()); + + // If the device order is different in the target, fix the order with + // ReshardWithCollectivePermute. 
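+  // For example, resharding from {devices=[2,1]0,1} to {devices=[1,2]1,0}
+  // first collective-permutes to {devices=[2,1]1,0}, so the all-to-all below
+  // only has to move data between the two tiled dimensions and not between
+  // device assignments.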
+ auto input_tile_fixed_device_order = target.tile_assignment(); + input_tile_fixed_device_order.Reshape( + sharding().tile_assignment().dimensions()); + auto input_sharding_fixed_device_order = + HloSharding::Tile(input_tile_fixed_device_order); + if (input_sharding_fixed_device_order != sharding()) { + auto fixed_order = + ReshardWithCollectivePermute(input_sharding_fixed_device_order); + return fixed_order.ReshardWithAllToAll(target); + } + + auto padded_hlo = + PadBaseShapeBeforeUnevenTiledSharding(hlo_, target, state_.b); + + // The order of ids in the group must follow the target sharding. + std::vector groups(1); + for (int64 device : target.tile_assignment()) { + groups[0].add_replica_ids(device); + } + + HloInstruction* result = nullptr; + + // Split along the split dimension (output_partition_dim) of the all-to-all + // output. + std::vector dimensions; + for (int64 i = 0; i < base_shape_.rank(); ++i) { + if (i == *output_partition_dim) { + dimensions.push_back(partition_count); + dimensions.push_back(padded_hlo->shape().dimensions(i) / partition_count); + } else { + dimensions.push_back(padded_hlo->shape().dimensions(i)); + } + } + auto reshape = state_.b->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(base_shape_.element_type(), dimensions), + padded_hlo)); + // After the reshape, it is guaranteed to have at least 3 dimensions. + auto all_to_all = + state_.collective_ops_creator.create_cross_partition_all_to_all( + state_.b, {reshape}, groups, (*state_.next_channel_id)++, + output_partition_dim); + + // Reorder the split dimension of the reshape to be located in front of the + // input partition dimension, so the two dimensions can be combined. + int64 new_input_partition_dim = (*output_partition_dim < *input_partition_dim) + ? *input_partition_dim + 1 + : *input_partition_dim; + std::vector permutation; + for (int64 i = 0; i < all_to_all->shape().rank(); ++i) { + if (i == *output_partition_dim) { + continue; + } + if (i == new_input_partition_dim) { + permutation.push_back(*output_partition_dim); + } + permutation.push_back(i); + } + auto transpose = state_.b->AddInstruction(HloInstruction::CreateTranspose( + ShapeInference::InferTransposeShape(all_to_all->shape(), permutation) + .ValueOrDie(), + all_to_all, permutation)); + + // Combine the split dimension and the input partition dimension. 
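+  // For example, with 2 partitions, resharding f32[8,16] from
+  // {devices=[2,1]0,1} (local shard f32[4,16]) to {devices=[1,2]0,1} (local
+  // shard f32[8,8]) goes reshape f32[4,16] -> f32[4,2,8], all-to-all on the
+  // size-2 dimension, transpose to f32[2,4,8], and the reshape below merges
+  // the two leading dimensions into the final f32[8,8] shard.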
+ auto new_shape = ShapeInference::InferAllToAllShape( + padded_hlo->shape(), *output_partition_dim, + *input_partition_dim, partition_count) + .ValueOrDie(); + result = state_.b->AddInstruction( + HloInstruction::CreateReshape(new_shape, transpose)); + + const Shape result_shape = MakePartitionedShape(base_shape_, target); + if (result_shape != result->shape()) { + result = state_.b->AddInstruction(HloInstruction::CreateSlice( + result_shape, result, std::vector(result_shape.rank(), 0), + result_shape.dimensions(), std::vector(result_shape.rank(), 1))); + } + result->set_sharding(target); + return PartitionedHlo(result, base_shape_, state_); +} + +PartitionedHlo PartitionedHlo::ReshardWithCollectivePermute( + const HloSharding& target) const { + CHECK(CanReshardWithCollectivePermute(sharding(), target)); + std::vector> src_dst_pairs; + sharding().tile_assignment().Each( + [&](absl::Span indices, int64 src_device) { + int64 dst_device = target.tile_assignment()(indices); + if (dst_device != src_device) { + src_dst_pairs.emplace_back(src_device, dst_device); + } + }); + auto cp = + state_.collective_ops_creator.create_cross_partition_collective_permute( + state_.b, hlo(), src_dst_pairs, (*state_.next_channel_id)++); + cp->set_sharding(target); + return PartitionedHlo(cp, base_shape_, state_); +} + +SpmdPartitioningVisitor::SpmdPartitioningVisitor( + HloComputation* computation, int64 num_partitions, int64 num_replicas, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdLogger* logger, SpmdPartitionerOptions options, + SpmdPartitioner* partitioner) + : changed_(false), + module_(computation->parent()), + num_partitions_(num_partitions), + num_replicas_(num_replicas), + collective_ops_creator_(collective_ops_creator), + next_channel_id_(next_channel_id), + b_(SpmdBuilder(computation->name() + "_spmd", /*hlo=*/nullptr)), + partition_id_(collective_ops_creator_.create_partition_id(&b_)), + logger_(logger), + options_(std::move(options)), + partitioner_(partitioner) {} + +Status SpmdPartitioningVisitor::DefaultAction(HloInstruction* hlo) { + if (hlo->HasSideEffect()) { + return Unimplemented("Side-effect ops cannot be replicated: %s", + hlo->ToString()); + } + + if (hlo->IsElementwise() && hlo->operand_count() > 0) { + return HandleElementwise(hlo); + } + + if (!hlo->sharding().IsTileMaximal()) { + VLOG(1) << "Not partitioned in SPMD mode (DefaultAction):" + << hlo->ToString(); + for (int64 i = 0; i < hlo->operand_count(); ++i) { + VLOG(1) << " operand " << i + << " sharding:" << hlo->operand(i)->sharding().ToString(); + } + } + + // If the instruction cannot be partitioned, replicate the instruction unless + // the instruction has side-effect. 
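+  // For example, an op with sharding {devices=[2,1]0,1} but no dedicated
+  // handler is recomputed on every partition from fully replicated operands,
+  // and the replicated result is then resharded back to {devices=[2,1]0,1}
+  // so its users still see the sharding they expect.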
+ std::vector new_operands; + for (HloInstruction* operand : hlo->operands()) { + new_operands.push_back( + GetPartitionedHlo(operand).Reshard(HloSharding::Replicate()).hlo()); + } + auto clone = + b_.AddInstruction(hlo->CloneWithNewOperands(hlo->shape(), new_operands)); + clone->set_sharding(HloSharding::Replicate()); + clone->set_metadata(hlo->metadata()); + SetPartitionedHlo(hlo, + PartitionedHlo(clone, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding())); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::Preprocess(HloInstruction* hlo) { + visiting_hlo_ = hlo; + b_.set_visiting_hlo(hlo); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::Postprocess(HloInstruction* hlo) { + logger_->RegisterLogEntry(GetPartitionedHlo(hlo).hlo(), + b_.derived_instructions(hlo)); + visiting_hlo_ = nullptr; + b_.set_visiting_hlo(nullptr); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleElementwise(HloInstruction* hlo) { + std::vector new_operands; + for (HloInstruction* operand : hlo->operands()) { + new_operands.push_back( + GetPartitionedHlo(operand).Reshard(hlo->sharding()).hlo()); + } + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction(hlo->CloneWithNewOperands( + MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands)); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleConcatenate(HloInstruction* hlo) { + const HloSharding& sharding = hlo->sharding(); + if (sharding.IsTileMaximal()) { + return DefaultAction(hlo); + } + + const Shape shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); + const int64 dimension = hlo->concatenate_dimension(); + if (sharding.tile_assignment().dim(dimension) == 1) { + std::vector new_operands; + for (HloInstruction* operand : hlo->operands()) { + new_operands.push_back( + GetPartitionedHlo(operand).Reshard(sharding).hlo()); + } + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction( + hlo->CloneWithNewOperands(shard_shape, new_operands)); + }); + return Status::OK(); + } + + // If the concatenate dimension is along one of the partitioned dimensions, + // allocate the full output shape, each partition updates its owned region, + // all-reduce across partitions, and then slice its output region. + + // We currently don't support subgroup all-reduce along partitions, so more + // than 1 partitioned dimensions is not supported. + if (sharding.tile_assignment().dim(dimension) != num_partitions_) { + return DefaultAction(hlo); + } + + // temp_output_shape is the output shape where the concatenate dimension + // is changed to the full (and padded to shard count) dimension size. + auto temp_output_shape = MakePartitionedShape(hlo->shape(), sharding); + temp_output_shape.set_dimensions( + dimension, temp_output_shape.dimensions(dimension) * + sharding.tile_assignment().dim(dimension)); + auto temp_output = CreateZero(temp_output_shape, &b_); + + // Offset of each operand along the concatenate dimension. 
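+  // For example, concatenating f32[4,6] and f32[4,10] along dimension 1 with
+  // sharding {devices=[1,2]0,1}: temp_output is a zero f32[4,16] buffer; the
+  // first operand's f32[4,3] shard is written at column 3*ordinal and the
+  // second operand's f32[4,5] shard at column 6 + 5*ordinal; the all-reduce
+  // then combines the per-partition writes before each partition slices out
+  // its own f32[4,8] piece of the result.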
+ int64 offset = 0; + for (HloInstruction* operand : hlo->operands()) { + auto spmd_operand = GetPartitionedHlo(operand).Reshard(sharding).hlo(); + std::vector start_indices( + hlo->shape().rank(), b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(S32)))); + start_indices[dimension] = + MultiplyAddDivideOffsetCalculation( + spmd_operand->shape().dimensions(dimension), offset, 1) + .Calculate(MakeTiledPartitionOrdinals(sharding, partition_id_, + &b_)[dimension], + &b_); + temp_output = b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + temp_output_shape, temp_output, spmd_operand, start_indices)); + offset += operand->shape().dimensions(dimension); + } + auto all_reduce = collective_ops_creator_.create_cross_partition_all_reduce( + &b_, temp_output, MakeBinaryAdd(hlo->shape().element_type(), module_), + NewChannel()); + SetPartitionedHlo(hlo, [&] { + auto start_indices = + MakeTiledPartitionOrdinals(hlo->sharding(), partition_id_, &b_); + start_indices[dimension] = MultiplyAddDivideOffsetCalculation( + shard_shape.dimensions(dimension), 0, 1) + .Calculate(start_indices[dimension], &b_); + return b_.AddInstruction(HloInstruction::CreateDynamicSlice( + shard_shape, all_reduce, start_indices, shard_shape.dimensions())); + }); + + return Status::OK(); +} + +// If partitioning in the operand only happens in dimensions in passthrough +// dimensions (offset dimensions in the gather output (or scatter update) that +// have the same size as the operand), returns the corresponding output (or +// update) sharding by passing through the input sharding. +absl::optional PassthroughOperandToGatherOutputOrScatterUpdate( + const PartitionedHlo& operand, const Shape& update_or_gather_shape, + absl::Span collapsed_or_inserted_dims, + absl::Span index_map, + absl::Span offset_or_window_dims, + absl::Span slice_size) { + if (operand.sharding().IsTileMaximal()) { + return operand.sharding(); + } + std::vector passthrough_tile(update_or_gather_shape.rank(), 1); + int64 collapsed = 0; + for (int64 i = 0; i < operand.base_shape().rank(); ++i) { + int64 dim_partitions = operand.sharding().tile_assignment().dim(i); + if (absl::c_linear_search(collapsed_or_inserted_dims, i) || + absl::c_linear_search(index_map, i)) { + if (dim_partitions > 1) { + return absl::nullopt; + } + collapsed++; + continue; + } + if (slice_size[i] != operand.base_shape().dimensions(i) && + dim_partitions > 1) { + return absl::nullopt; + } + int64 offset_dim = offset_or_window_dims[i - collapsed]; + if (i - collapsed > 0 && + offset_dim < offset_or_window_dims[i - collapsed - 1]) { + // Output offsets are transposed, we do not support this case. + return absl::nullopt; + } + passthrough_tile[offset_dim] = dim_partitions; + } + Array tile_assignment = operand.sharding().tile_assignment(); + tile_assignment.Reshape(passthrough_tile); + return HloSharding::Tile(tile_assignment); +} + +// Returns whether partitioning in the operand only happens in dimensions with +// gather/scatter slice size 1. 
+bool GatherScatterOperandPartitionedOnlyOnTrivialSliceDims(
+    const PartitionedHlo& operand, absl::Span<const int64> index_map,
+    absl::Span<const int64> slice_size, int64 num_partitions) {
+  if (operand.sharding().IsTileMaximal()) {
+    return false;
+  }
+  int64 trivial_slice_dims_partitions = 1;
+  for (int64 dim : index_map) {
+    if (slice_size[dim] == 1) {
+      trivial_slice_dims_partitions *=
+          operand.sharding().tile_assignment().dim(dim);
+    }
+  }
+  return trivial_slice_dims_partitions == num_partitions;
+}
+
+// Returns the min and max for the indices (replicated) in a scatter/gather
+// which has the operand partitioned on trivial slice dimensions (slice size 1).
+std::pair<HloInstruction*, HloInstruction*>
+IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims(
+    const PartitionedHlo& operand, const PartitionedHlo& replicated_indices,
+    HloInstruction* partition_id, absl::Span<const int64> index_map,
+    int64 index_vector_dim, SpmdBuilder* b) {
+  auto operand_offsets = MakePartitionOffsets(
+      operand.base_shape(), operand.sharding(), partition_id, b);
+  // Find the per-dimension index bounds.
+  std::vector<HloInstruction*> min_indices;
+  std::vector<HloInstruction*> max_indices;
+  for (int64 i = 0; i < index_map.size(); ++i) {
+    int64 dim = index_map[i];
+    int64 partitions = operand.sharding().tile_assignment().dim(dim);
+    if (partitions == 1) {
+      min_indices.push_back(CreateR0WithType(
+          replicated_indices.base_shape().element_type(), 0, b));
+      max_indices.push_back(CreateR0WithType(
+          replicated_indices.base_shape().element_type(),
+          operand.base_shape().dimensions(dim), b));
+      continue;
+    }
+    auto offset = operand_offsets[dim];
+    if (offset->shape().element_type() !=
+        replicated_indices.base_shape().element_type()) {
+      offset = b->AddInstruction(HloInstruction::CreateConvert(
+          ShapeUtil::MakeShape(replicated_indices.base_shape().element_type(),
+                               {}),
+          offset));
+    }
+    min_indices.push_back(offset);
+    auto partition_size_minus_1 =
+        CreateR0WithType(replicated_indices.base_shape().element_type(),
+                         operand.hlo()->shape().dimensions(dim) - 1, b);
+    max_indices.push_back(b->AddInstruction(HloInstruction::CreateBinary(
+        offset->shape(), HloOpcode::kAdd, offset, partition_size_minus_1)));
+  }
+  // Broadcast the index bounds to the same shape as the indices.
+  HloInstruction* broadcast_min;
+  HloInstruction* broadcast_max;
+  if (index_vector_dim < replicated_indices.base_shape().rank()) {
+    // The index vector is an R1, we need to reshape individual bounds to
+    // [1], and concat them if there are more than one.
+ for (int64 i = 0; i < min_indices.size(); ++i) { + min_indices[i] = b->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(min_indices[i]->shape().element_type(), {1}), + min_indices[i])); + max_indices[i] = b->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(max_indices[i]->shape().element_type(), {1}), + max_indices[i])); + } + int64 slice_dims = max_indices.size(); + if (slice_dims > 1) { + min_indices[0] = b->AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(min_indices[0]->shape().element_type(), + {slice_dims}), + min_indices, 0)); + max_indices[0] = b->AddInstruction(HloInstruction::CreateConcatenate( + min_indices[0]->shape(), max_indices, 0)); + } + broadcast_min = b->AddInstruction(HloInstruction::CreateBroadcast( + replicated_indices.base_shape(), min_indices[0], {index_vector_dim})); + broadcast_max = b->AddInstruction(HloInstruction::CreateBroadcast( + replicated_indices.base_shape(), max_indices[0], {index_vector_dim})); + } else { + CHECK_EQ(max_indices.size(), 1); + broadcast_min = b->AddInstruction(HloInstruction::CreateBroadcast( + replicated_indices.base_shape(), min_indices[0], {})); + broadcast_max = b->AddInstruction(HloInstruction::CreateBroadcast( + replicated_indices.base_shape(), max_indices[0], {})); + } + return {broadcast_min, broadcast_max}; +} + +Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) { + auto scatter = Cast(hlo); + auto dnums = scatter->scatter_dimension_numbers(); + auto operand = GetPartitionedHlo(scatter->operand(0)); + auto indices = GetPartitionedHlo(scatter->operand(1)); + auto updates = GetPartitionedHlo(scatter->operand(2)); + std::vector slice_size(operand.base_shape().rank(), 1); + int64 num_update_window_dims = 0; + for (int64 i = 0; i < operand.base_shape().rank(); ++i) { + if (absl::c_linear_search(dnums.inserted_window_dims(), i)) { + continue; + } + slice_size[i] = updates.base_shape().dimensions( + dnums.update_window_dims(num_update_window_dims++)); + } + std::vector inserted_window_dims(dnums.inserted_window_dims().begin(), + dnums.inserted_window_dims().end()); + std::vector scatter_dims_to_operand_dims( + dnums.scatter_dims_to_operand_dims().begin(), + dnums.scatter_dims_to_operand_dims().end()); + std::vector update_window_dims(dnums.update_window_dims().begin(), + dnums.update_window_dims().end()); + if (!operand.sharding().IsTileMaximal()) { + auto maybe_passthrough = PassthroughOperandToGatherOutputOrScatterUpdate( + operand, updates.base_shape(), inserted_window_dims, + scatter_dims_to_operand_dims, update_window_dims, slice_size); + // Handle pass through cases if we can use compatible sharding for update. 
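+    // For example, if the operand is tiled only along dimensions that appear
+    // as full-size update window dimensions, the same tiling can be applied
+    // to the updates: indices are replicated, updates are resharded to the
+    // pass-through sharding, and each partition runs an ordinary scatter on
+    // its local shard with the operand sharding preserved.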
+ if (maybe_passthrough.has_value()) { + indices = indices.Reshard(HloSharding::Replicate()); + updates = updates.Reshard(*maybe_passthrough); + auto pscatter = b_.AddInstruction(HloInstruction::CreateScatter( + operand.hlo()->shape(), operand.hlo(), indices.hlo(), updates.hlo(), + scatter->to_apply(), dnums, scatter->indices_are_sorted(), + scatter->unique_indices())); + pscatter->set_sharding(*maybe_passthrough); + SetPartitionedHlo(hlo, [&]() { + return PartitionedHlo(pscatter, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); + } + if (GatherScatterOperandPartitionedOnlyOnTrivialSliceDims( + operand, scatter_dims_to_operand_dims, slice_size, + num_partitions_) && + ShapeSizeInBytes(updates.base_shape()) < + ShapeSizeInBytes(scatter->shape())) { + // Operand is sharded on trivial slice dims (update slice size 1). We can + // adjust the indices on each partition by subtracting the offsets. Then + // we execute a scatter on full updated indices, and out-of-bound accesses + // will have no effect on the result as guaranteed by the scatter + // semantics. + indices = indices.Reshard(HloSharding::Replicate()); + updates = updates.Reshard(HloSharding::Replicate()); + HloInstruction* indices_min; + HloInstruction* indices_max_unused; + std::tie(indices_min, indices_max_unused) = + IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims( + operand, indices, partition_id_, scatter_dims_to_operand_dims, + dnums.index_vector_dim(), &b_); + auto adjusted_indices = b_.AddInstruction(HloInstruction::CreateBinary( + indices.hlo()->shape(), HloOpcode::kSubtract, indices.hlo(), + indices_min)); + auto pscatter = b_.AddInstruction(HloInstruction::CreateScatter( + operand.hlo()->shape(), operand.hlo(), adjusted_indices, + updates.hlo(), scatter->to_apply(), dnums, + scatter->indices_are_sorted(), scatter->unique_indices())); + pscatter->set_sharding(operand.sharding()); + SetPartitionedHlo(hlo, [&]() { + return PartitionedHlo(pscatter, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); + } + } + return DefaultAction(hlo); +} + +Status SpmdPartitioningVisitor::HandleSlice(HloInstruction* hlo) { + const HloSharding& sharding = hlo->sharding(); + if (sharding.IsTileMaximal()) { + return DefaultAction(hlo); + } + + auto operand = GetPartitionedHlo(hlo->operand(0)).Reshard(sharding); + + // Create a window config to represent the slice. 
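+  // For example, slicing elements [2, 10) of a size-16 dimension becomes a
+  // size-1 window with padding_low = -2 and padding_high = 10 - 16 = -6, so
+  // ReshardAsWindowedInput can treat the slice like any other windowed op
+  // (the negative padding trims the unused elements).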
+ Window window; + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_size(1); + dim->set_stride(hlo->slice_strides(i)); + dim->set_window_dilation(1); + dim->set_window_reversal(false); + dim->set_padding_low(-hlo->slice_starts(i)); + dim->set_padding_high(hlo->slice_limits(i) - + hlo->operand(0)->shape().dimensions(i)); + dim->set_base_dilation(1); + } + + auto reshard_operand = operand.ReshardAsWindowedInput( + window, sharding, + CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_), + /*mask_invalid_region=*/false); + if (!reshard_operand.has_value()) { + return DefaultAction(hlo); + } + TF_RET_CHECK(!reshard_operand->dynamic_slice_index_on_output.has_value()); + const Shape& operand_shape = reshard_operand->sharded_input->shape(); + + std::vector start_indices = hlo->slice_starts(); + std::vector limit_indices = hlo->slice_limits(); + std::vector strides = hlo->slice_strides(); + bool need_slice = false; + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + auto dim = reshard_operand->shard_window.dimensions(i); + start_indices[i] = -dim.padding_low(); + limit_indices[i] = operand_shape.dimensions(i) + dim.padding_high(); + if (start_indices[i] != 0 || strides[i] != 1 || + limit_indices[i] != operand_shape.dimensions(i)) { + need_slice = true; + } + } + + SetPartitionedHlo(hlo, [&] { + if (need_slice) { + auto shard_shape = MakePartitionedShape(hlo->shape(), sharding); + return b_.AddInstruction(HloInstruction::CreateSlice( + shard_shape, reshard_operand->sharded_input, start_indices, + limit_indices, strides)); + } + return reshard_operand->sharded_input; + }); + + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleSort(HloInstruction* hlo) { + HloSharding sharding = hlo->sharding(); + // Special handling for sort in TopK when first operand partitioined at + // sort dimension. + auto k = GetKValueInTopKWhenPartitionSortDim(hlo); + if (k.has_value()) { + // When the first operand partitioned at sort dimension: + // 1. Partition sort computation to different partitions; + // 2. Slice TopK value and index from different partitions; + // 3. Gather and replicate value and index from different partitions, + // the shape of replicated value and index will be + // [batch_size, ..., partition_count * k, ...]; + // 4. Final sort uses replicated value and index from different partitions + // as input. + // GetTupleElement and Slice after the non-partitoned sort won't change + // at this point, as HandleGetTupleElement and HandleSlice will update them. + HloSortInstruction* sort = DynCast(hlo); + const int64 sort_dim = sort->sort_dimension(); + auto input = hlo->operand(0); + auto index = hlo->operand(1); + const HloSharding& input_sharding = input->sharding(); + const int64 partition_count = + input_sharding.tile_assignment().dim(sort_dim); + const int64 input_size = input->shape().dimensions(sort_dim); + const int64 per_partition_size = CeilOfRatio(input_size, partition_count); + const auto element_type = input->shape().element_type(); + const auto index_type = index->shape().element_type(); + + // Partition and pad input and index. + // Pad input with minimal value. + auto partitioned_input = GetPartitionedHlo(input).PadWithValue( + CreateFirstWithType(element_type, &b_)); + // Pad index with max value. 
+ auto partitioned_index = + GetPartitionedHlo(index) + .Reshard(input_sharding) + .PadWithValue(CreateLastWithType(index_type, &b_)); + + // Each partition needs to do TopK separately, thus the base shape + // becomes the padded shape. + std::vector replicated_dimensions( + input->shape().dimensions().begin(), input->shape().dimensions().end()); + replicated_dimensions[sort_dim] = per_partition_size * partition_count; + const Shape replicated_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(element_type, replicated_dimensions), + ShapeUtil::MakeShape(index_type, replicated_dimensions)}); + + // Partition original topk to different shards. + auto topk_sharding = + input_sharding.GetTupleSharding(replicated_shape).ValueOrDie(); + auto shard_shape = MakePartitionedShape(replicated_shape, topk_sharding); + auto topk = b_.AddInstruction(hlo->CloneWithNewOperands( + shard_shape, {partitioned_input.hlo(), partitioned_index.hlo()})); + + // Get value from first sort. + HloInstruction* value_gte = + b_.AddInstruction(HloInstruction::CreateGetTupleElement( + topk->shape().tuple_shapes(0), topk, 0)); + HloInstruction* index_gte = + b_.AddInstruction(HloInstruction::CreateGetTupleElement( + topk->shape().tuple_shapes(1), topk, 1)); + + // Slice top K value from the first partitioned sort. + replicated_dimensions[sort_dim] = k.value() * partition_count; + auto slice_input = SliceFirstK(value_gte, &b_, sort_dim, k.value()); + slice_input->set_sharding(input_sharding); + PartitionedHlo partitioned_slice_input( + slice_input, ShapeUtil::MakeShape(element_type, replicated_dimensions), + MakePartitioningState()); + // Reshard value to be replicated. + auto replicated_slice_input = + partitioned_slice_input.Reshard(HloSharding::Replicate()).hlo(); + + // Slice top K index from the first parttioned sort. + auto slice_index = SliceFirstK(index_gte, &b_, sort_dim, k.value()); + slice_index->set_sharding(input_sharding); + PartitionedHlo partitioned_slice_index( + slice_index, ShapeUtil::MakeShape(index_type, replicated_dimensions), + MakePartitioningState()); + // Reshard value to be replicated. + auto replicated_slice_index = + partitioned_slice_index.Reshard(HloSharding::Replicate()).hlo(); + + // Creates replicated sort to do TopK, the input is value and index pairs + // from all the partitions. + const Shape final_topk_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(element_type, replicated_dimensions), + ShapeUtil::MakeShape(index_type, replicated_dimensions)}); + auto final_sort = b_.AddInstruction(HloInstruction::CreateSort( + final_topk_shape, sort_dim, + {replicated_slice_input, replicated_slice_index}, sort->to_apply(), + sort->is_stable())); + final_sort->set_sharding(HloSharding::Replicate() + .GetTupleSharding(final_sort->shape()) + .ValueOrDie()); + PartitionedHlo replicated_sort(final_sort, final_topk_shape, + MakePartitioningState()); + SetPartitionedHlo(hlo, replicated_sort.Reshard(hlo->sharding())); + + return Status::OK(); + } + + if (hlo->shape().IsTuple()) { + // Check that all elements are sharded in the same way. 
+ if (hlo->shape().tuple_shapes_size() == 0) { + return DefaultAction(hlo); + } + sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0}); + for (int64 i = 1; i < hlo->operand_count(); ++i) { + if (sharding != hlo->sharding().GetSubSharding(hlo->shape(), {i})) { + return DefaultAction(hlo); + } + } + } + if (sharding.IsTileMaximal()) { + return DefaultAction(hlo); + } + for (int64 dim : hlo->dimensions()) { + if (sharding.tile_assignment().dim(dim) > 1) { + return DefaultAction(hlo); + } + } + // Reshard operands to the same as the output. + std::vector new_operands; + for (HloInstruction* operand : hlo->operands()) { + new_operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo()); + } + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction(hlo->CloneWithNewOperands( + MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands)); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) { + if (hlo->custom_call_target() == "SPMDFullToShardShape") { + // This op switches from auto partitioning to manual partitioning. + auto input_partitioned = GetPartitionedHlo(hlo->operand(0)); + if (!EvenlyPartitions(hlo->shape(), input_partitioned.sharding())) { + input_partitioned = input_partitioned.PadWithValue( + CreateR0WithType(hlo->shape().element_type(), 0, &b_)); + } + auto input = input_partitioned.hlo(); + CHECK(hlo->sharding().IsReplicated()); + CHECK(ShapeUtil::Compatible(input->shape(), hlo->shape())); + auto copy = b_.AddInstruction( + HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input)); + SetPartitionedHlo(hlo, [&] { return copy; }); + return Status::OK(); + } + if (hlo->custom_call_target() == "SPMDShardToFullShape") { + // This op switches from manual partitioning to auto partitioning. + auto input = GetPartitionedHlo(hlo->operand(0)).hlo(); + CHECK(input->sharding().IsReplicated()); + auto copy = b_.AddInstruction( + HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input)); + CHECK(ShapeUtil::Compatible( + copy->shape(), MakePartitionedShape(hlo->shape(), hlo->sharding()))); + SetPartitionedHlo(hlo, [&] { return copy; }); + return Status::OK(); + } + if (hlo->custom_call_target() != "TopK") { + return DefaultAction(hlo); + } + + if (!hlo->operand(0)->has_sharding()) { + return DefaultAction(hlo); + } + + const HloSharding& sharding = hlo->operand(0)->sharding(); + if (sharding.IsTileMaximal() || sharding.IsReplicated()) { + return DefaultAction(hlo); + } + + const int64 sort_dim = 1; + const int64 shard_count = sharding.tile_assignment().dim(sort_dim); + + if (shard_count <= 1) { + return DefaultAction(hlo); + } + + const int64 input_size = hlo->operand(0)->shape().dimensions(sort_dim); + const int64 batch_size = hlo->shape().tuple_shapes(0).dimensions(0); + const int64 k = hlo->shape().tuple_shapes(0).dimensions(sort_dim); + const int64 per_partition_size = CeilOfRatio(input_size, shard_count); + + if (k >= per_partition_size) { + return DefaultAction(hlo); + } + + auto input = hlo->operand(0); + const auto element_type = input->shape().element_type(); + + auto partitioned_input = GetPartitionedHlo(input).PadWithValue( + CreateFirstWithType(element_type, &b_)); + + // Each partition needs to do TopK separately, thus the base shape + // becomes [batch_size, k * shard_count]. 
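+  // For example, with batch_size = 2, input_size = 1000, shard_count = 4
+  // (per_partition_size = 250) and k = 10, each partition runs TopK on its
+  // f32[2,250] shard, the per-partition candidates form replicated f32[2,40]
+  // value/index tensors (indices shifted by partition_id * 250), and a final
+  // replicated sort plus slice produces the global top 10.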
+  const Shape replicated_shape = ShapeUtil::MakeTupleShape(
+      {ShapeUtil::MakeShape(hlo->operand(0)->shape().element_type(),
+                            {batch_size, k * shard_count}),
+       ShapeUtil::MakeShape(S32, {batch_size, k * shard_count})});
+  auto custom_call_sharding =
+      sharding.GetTupleSharding(replicated_shape).ValueOrDie();
+  auto shard_shape =
+      MakePartitionedShape(replicated_shape, custom_call_sharding);
+  auto topk = b_.AddInstruction(
+      hlo->CloneWithNewOperands(shard_shape, {partitioned_input.hlo()}));
+  topk->set_sharding(custom_call_sharding);
+  // Partition customcall.
+  PartitionedHlo partitioned_topk(topk, replicated_shape,
+                                  MakePartitioningState());
+  topk = partitioned_topk.hlo();
+
+  // Get value from TopK.
+  HloInstruction* value_gte =
+      b_.AddInstruction(HloInstruction::CreateGetTupleElement(
+          topk->shape().tuple_shapes(0), topk, 0));
+  value_gte->set_sharding(sharding);
+  // Partition GetTupleElement of value.
+  PartitionedHlo value_partitioned_gte(
+      value_gte, partitioned_topk.base_shape().tuple_shapes(0),
+      MakePartitioningState());
+  // Reshard value to be replicated.
+  auto replicated_value_gte =
+      value_partitioned_gte.Reshard(HloSharding::Replicate()).hlo();
+
+  // Get index from TopK.
+  HloInstruction* index_gte =
+      b_.AddInstruction(HloInstruction::CreateGetTupleElement(
+          topk->shape().tuple_shapes(1), topk, 1));
+  auto partition_id_s32 = b_.AddInstruction(HloInstruction::CreateConvert(
+      ShapeUtil::MakeShape(S32, partition_id_->shape().dimensions()),
+      partition_id_));
+  // Add per-partition offset to the index; the index returned from the
+  // CustomCall always starts from 0.
+  auto index_offset = b_.AddInstruction(HloInstruction::CreateBroadcast(
+      index_gte->shape(),
+      b_.AddInstruction(HloInstruction::CreateBinary(
+          partition_id_s32->shape(), HloOpcode::kMultiply, partition_id_s32,
+          b_.AddInstruction(HloInstruction::CreateConstant(
+              LiteralUtil::CreateR0<int32>(per_partition_size))))),
+      {}));
+  index_gte = b_.AddInstruction(HloInstruction::CreateBinary(
+      index_offset->shape(), HloOpcode::kAdd, index_gte, index_offset));
+  index_gte->set_sharding(sharding);
+  // Partition GetTupleElement of index.
+  PartitionedHlo index_partitioned_gte(
+      index_gte, partitioned_topk.base_shape().tuple_shapes(1),
+      MakePartitioningState());
+  // Reshard index to be replicated.
+  auto replicated_index_gte =
+      index_partitioned_gte.Reshard(HloSharding::Replicate()).hlo();
+
+  // Creates a replicated sort to do TopK; the input is the value and index
+  // pairs from all the partitions. The reason to use Sort instead of the
+  // CustomCall TopK is that the CustomCall only takes the value as input, so
+  // there would be an extra Gather to get the correct index if the CustomCall
+  // were used here.
+
+  // Create comparator for the sort.
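+  // The comparator orders primarily by value (Gt) and breaks ties by index
+  // (Lt), so after the replicated sort the first k entries along sort_dim are
+  // the global top-k values together with their (offset-adjusted) indices.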
+ XlaBuilder b("Sort.Compare"); + XlaComputation comparator = CreateScalarComparisonComputation( + "compare-value-and-index", {input->shape().element_type(), S32}, {Gt, Lt}, + &b); + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comparator.GetProgramShape()); + HloModuleConfig config(program_shape); + TF_ASSIGN_OR_RETURN(auto new_module, + HloModule::CreateFromProto(comparator.proto(), config)); + HloCloneContext context(module_); + auto compare_computation = + module_->DeepCloneComputation(new_module->entry_computation(), &context); + auto sort = b_.AddInstruction(HloInstruction::CreateSort( + replicated_shape, sort_dim, {replicated_value_gte, replicated_index_gte}, + compare_computation, true)); + sort->set_sharding( + HloSharding::Replicate().GetTupleSharding(sort->shape()).ValueOrDie()); + PartitionedHlo replicated_sort(sort, replicated_shape, + MakePartitioningState()); + + // Slice value and index from top-k for output. + HloInstruction* sort_value_gte = + b_.AddInstruction(HloInstruction::CreateGetTupleElement( + replicated_sort.hlo()->shape().tuple_shapes(0), replicated_sort.hlo(), + 0)); + HloInstruction* sort_index_gte = + b_.AddInstruction(HloInstruction::CreateGetTupleElement( + replicated_sort.hlo()->shape().tuple_shapes(1), replicated_sort.hlo(), + 1)); + // Slice value from final sort. + HloInstruction* slice_sort_value = + SliceFirstK(sort_value_gte, &b_, sort_dim, k); + // Slice index from final sort. + HloInstruction* slice_index_value = + SliceFirstK(sort_index_gte, &b_, sort_dim, k); + auto create_tuple = b_.AddInstruction( + HloInstruction::CreateTuple({slice_sort_value, slice_index_value})); + create_tuple->set_sharding(HloSharding::Replicate()); + + SetPartitionedHlo(hlo, PartitionedHlo(create_tuple, create_tuple->shape(), + MakePartitioningState()) + .Reshard(hlo->sharding())); + + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleTranspose(HloInstruction* hlo) { + const HloSharding& sharding = hlo->sharding(); + if (sharding.IsTileMaximal()) { + return DefaultAction(hlo); + } + + std::vector inverse_dimensions(hlo->shape().rank()); + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + inverse_dimensions[hlo->dimensions(i)] = i; + } + auto desired_operand_sharding = + hlo_sharding_util::TransposeSharding(sharding, inverse_dimensions); + + auto operand = GetPartitionedHlo(hlo->operand(0)) + .Reshard(desired_operand_sharding) + .hlo(); + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction(hlo->CloneWithNewOperands( + MakePartitionedShape(hlo->shape(), hlo->sharding()), {operand})); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) { + const HloSharding& sharding = hlo->sharding(); + if (sharding.IsTileMaximal()) { + return DefaultAction(hlo); + } + + auto operand = GetPartitionedHlo(hlo->operand(0)); + // The output shape is the source and the operand shape is the target to get + // the aligned sharding for the operand. + auto desired_operand_sharding = hlo_sharding_util::ReshapeSharding( + hlo->shape(), hlo->operand(0)->shape(), hlo->sharding()); + if (desired_operand_sharding.has_value()) { + auto operand_hlo = operand.Reshard(*desired_operand_sharding).hlo(); + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction(hlo->CloneWithNewOperands( + MakePartitionedShape(hlo->shape(), hlo->sharding()), {operand_hlo})); + }); + return Status::OK(); + } + + // Try use halo exchange for certain split-dim/merge-dims cases. 
+ // ReshapeSharding failed in these cases probably due to uneven partitioning, + // where halo exchange could help. Specifically we check the following + // conditions to detect supported cases: + // 1) Both input and output are partitioned on one dimension. + // 2) The combined size of dimensions before the partitioned dimension are the + // same on input and output. This means we don't need to consider the major + // dimensions. + // 3) Let A = the input size on the partitioned dimension, and + // B = the output size on the partitioned dimension; then + // either A % B == 0 (split dim) or B % A == 0 (merge dims). + auto maybe_input_sharded_dim = UniqueTiledDim(operand.sharding()); + auto maybe_output_sharded_dim = UniqueTiledDim(sharding); + if (!maybe_input_sharded_dim || !maybe_output_sharded_dim) { + return DefaultAction(hlo); + } + int64 input_sharded_dim = *maybe_input_sharded_dim; + int64 output_sharded_dim = *maybe_output_sharded_dim; + // Check that the major dims before the sharded dim have the same total size + // for input and output. + int64 input_major_dims_size = 1; + for (int64 i = 0; i < input_sharded_dim; ++i) { + input_major_dims_size *= operand.base_shape().dimensions(i); + } + int64 output_major_dims_size = 1; + for (int64 i = 0; i < output_sharded_dim; ++i) { + output_major_dims_size *= hlo->shape().dimensions(i); + } + if (input_major_dims_size != output_major_dims_size) { + return DefaultAction(hlo); + } + // Fix potential device ordering mismatch in tile assignment. + Array new_input_tile_assignment = sharding.tile_assignment(); + new_input_tile_assignment.Reshape( + operand.sharding().tile_assignment().dimensions()); + operand = operand.Reshard(HloSharding::Tile(new_input_tile_assignment)); + + int64 input_dim_size = operand.base_shape().dimensions(input_sharded_dim); + int64 output_dim_size = hlo->shape().dimensions(output_sharded_dim); + auto input_shard_shape = + MakePartitionedShape(operand.base_shape(), operand.sharding()); + auto output_shard_shape = MakePartitionedShape(hlo->shape(), sharding); + if (input_dim_size % output_dim_size == 0) { + // Split dim. + int64 split_factor = input_dim_size / output_dim_size; + int64 output_shard_size = output_shard_shape.dimensions(output_sharded_dim); + // Use halo exchange to fix misaligned data. + Window window; + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_size(1); + dim->set_stride(1); + dim->set_window_dilation(1); + dim->set_window_reversal(false); + dim->set_base_dilation(1); + dim->set_padding_low(0); + if (i == input_sharded_dim) { + dim->set_padding_high(output_shard_size * split_factor * + num_partitions_ - + input_dim_size); + } else { + dim->set_padding_high(0); + } + } + + auto reshard_operand = operand.ReshardAsWindowedInput( + window, operand.sharding(), + CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_), + /*mask_invalid_region=*/false); + if (!reshard_operand.has_value()) { + return DefaultAction(hlo); + } + TF_RET_CHECK(!reshard_operand->dynamic_slice_index_on_output.has_value()); + CHECK_EQ( + reshard_operand->sharded_input->shape().dimensions(input_sharded_dim), + output_shard_size * split_factor); + SetPartitionedHlo(hlo, [&] { + // Do a local reshape. + return b_.AddInstruction(HloInstruction::CreateReshape( + output_shard_shape, reshard_operand->sharded_input)); + }); + return Status::OK(); + } else if (output_dim_size % input_dim_size == 0) { + // Merge dims. 
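+    // For example, reshaping f32[5,4] (tiled 2 ways on dim 0, so each shard
+    // holds a padded [3,4] piece) to f32[20]: input_dim_size = 5,
+    // output_dim_size = 20, so merge_factor = 4. Each shard first reshapes its
+    // local piece to [12]; the windowed reshard below then exchanges the
+    // boundary elements (and drops the padding) so that each shard ends up
+    // with its contiguous [10] slice of the output.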
+ int64 merge_factor = output_dim_size / input_dim_size; + // First reshape locally. (The sharded dimension could include padded data.) + auto tmp_shard_shape = output_shard_shape; + tmp_shard_shape.set_dimensions( + output_sharded_dim, + input_shard_shape.dimensions(input_sharded_dim) * merge_factor); + auto tmp_reshape = b_.AddInstruction( + HloInstruction::CreateReshape(tmp_shard_shape, operand.hlo())); + tmp_reshape->set_metadata(hlo->metadata()); + tmp_reshape->set_sharding(hlo->sharding()); + auto tmp_full_shape = tmp_shard_shape; + tmp_full_shape.set_dimensions( + output_sharded_dim, + tmp_shard_shape.dimensions(output_sharded_dim) * num_partitions_); + auto tmp_output = + PartitionedHlo(tmp_reshape, tmp_full_shape, MakePartitioningState()); + + // Use halo exchange to fix misaligned data. + Window window; + for (int64 i = 0; i < tmp_shard_shape.rank(); ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_size(1); + dim->set_stride(1); + dim->set_window_dilation(1); + dim->set_window_reversal(false); + dim->set_base_dilation(1); + dim->set_padding_low(0); + if (i == output_sharded_dim) { + dim->set_padding_high(output_dim_size - + tmp_shard_shape.dimensions(output_sharded_dim) * + num_partitions_); + } else { + dim->set_padding_high(0); + } + } + + auto reshard_output = tmp_output.ReshardAsWindowedInput( + window, sharding, + CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_), + /*mask_invalid_region=*/false); + if (!reshard_output.has_value()) { + return DefaultAction(hlo); + } + TF_RET_CHECK(!reshard_output->dynamic_slice_index_on_output.has_value()); + CHECK_EQ( + reshard_output->sharded_input->shape().dimensions(output_sharded_dim), + output_shard_shape.dimensions(output_sharded_dim)); + SetPartitionedHlo(hlo, [&] { return reshard_output->sharded_input; }); + return Status::OK(); + } + return DefaultAction(hlo); +} + +Status SpmdPartitioningVisitor::HandleIota(HloInstruction* hlo) { + const HloSharding& sharding = hlo->sharding(); + if (sharding.IsTileMaximal()) { + return DefaultAction(hlo); + } + + SetPartitionedHlo(hlo, [&] { + int64 dimension = Cast(hlo)->iota_dimension(); + auto iota = b_.AddInstruction(HloInstruction::CreateIota( + MakePartitionedShape(hlo->shape(), sharding), dimension)); + + if (sharding.tile_assignment().dim(dimension) > 1) { + auto partition_ordinals = + MakeTiledPartitionOrdinals(sharding, partition_id_, &b_); + auto multiplier = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(iota->shape().dimensions(dimension)))); + auto offset = b_.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(S32, {}), HloOpcode::kMultiply, + partition_ordinals[dimension], multiplier)); + if (iota->shape().element_type() != S32) { + offset = b_.AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::MakeShape(iota->shape().element_type(), {}), offset)); + } + auto broadcast = b_.AddInstruction( + HloInstruction::CreateBroadcast(iota->shape(), offset, {})); + return b_.AddInstruction(HloInstruction::CreateBinary( + iota->shape(), HloOpcode::kAdd, iota, broadcast)); + } + + return iota; + }); + + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleSingleDevice(const HloInstruction* hlo) { + TF_RET_CHECK(hlo->sharding().HasUniqueDevice()); + int64 device = hlo->sharding().GetUniqueDevice(); + const HloSharding sharding = HloSharding::AssignDevice(device); + + std::vector operands; + std::vector operand_shapes; + for (const HloInstruction* operand : hlo->operands()) { + 
operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo()); + operand_shapes.push_back(operand->shape()); + } + auto operand = b_.AddInstruction(HloInstruction::CreateTuple(operands)); + auto operand_shape = ShapeUtil::MakeTupleShape(operand_shapes); + + auto on_device = b_.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(device))); + auto pred = b_.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), partition_id_, on_device, + ComparisonDirection::kEq)); + + SpmdBuilder true_b("true_computation", visiting_hlo_); + HloComputation* true_computation; + { + auto param = true_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, operand_shape, "true_branch_param")); + std::vector new_operands; + for (int64 i = 0; i < operands.size(); ++i) { + new_operands.push_back(true_b.AddInstruction( + HloInstruction::CreateGetTupleElement(operand_shapes[i], param, i))); + } + auto root = true_b.AddInstruction( + hlo->CloneWithNewOperands(hlo->shape(), new_operands)); + true_computation = module_->AddEmbeddedComputation(true_b.Build(root)); + } + + SpmdBuilder false_b("false_computation", visiting_hlo_); + HloComputation* false_computation; + { + false_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, operand_shape, "false_branch_param")); + auto root = CreateZero(hlo->shape(), &false_b); + false_computation = module_->AddEmbeddedComputation(false_b.Build(root)); + } + + SetPartitionedHlo(hlo, [&]() { + return b_.AddInstruction(HloInstruction::CreateConditional( + hlo->shape(), pred, operand, true_computation, operand, + false_computation)); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleAllReduce(HloInstruction* hlo) { + if (hlo->IsCrossReplicaAllReduce() && hlo->operand_count() == 1) { + return HandleElementwise(hlo); + } + return DefaultAction(hlo); +} + +Status SpmdPartitioningVisitor::HandleBroadcast(HloInstruction* hlo) { + if (hlo->sharding().IsTileMaximal()) { + return DefaultAction(hlo); + } + + auto& operand = GetPartitionedHlo(hlo->operand(0)); + + // Tiled output. + std::vector wanted_input_tile_size(operand.base_shape().rank()); + std::vector sharded_new_dims; + for (int64 i = 0; i < operand.base_shape().rank(); ++i) { + wanted_input_tile_size[i] = + hlo->sharding().tile_assignment().dim(hlo->dimensions(i)); + } + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + if (!absl::c_linear_search(hlo->dimensions(), i) && + hlo->sharding().tile_assignment().dim(i) > 1) { + sharded_new_dims.push_back(i); + } + } + if (sharded_new_dims.empty()) { + // The new dimensions are replicated, so that we can do the adjustment on + // the input. + Array wanted_input_tile_assignment(wanted_input_tile_size); + wanted_input_tile_assignment.Each( + [&](absl::Span indices, int64* val) { + std::vector indices_in_broadcast(hlo->shape().rank(), 0); + for (int64 i = 0; i < operand.base_shape().rank(); ++i) { + indices_in_broadcast[hlo->dimensions(i)] = indices[i]; + } + *val = hlo->sharding().tile_assignment()(indices_in_broadcast); + }); + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction(hlo->CloneWithNewOperands( + MakePartitionedShape(hlo->shape(), hlo->sharding()), + {operand.Reshard(HloSharding::Tile(wanted_input_tile_assignment)) + .hlo()})); + }); + } else { + auto input = operand.Reshard(HloSharding::Replicate()).hlo(); + // We pad and shard the input first, then broadcast to the final shard + // shape. 
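+    // Concretely: the replicated input is padded so that each operand
+    // dimension becomes a multiple of the output tiling along the matching
+    // broadcast dimension, each partition then dynamic-slices the piece that
+    // maps onto its output shard, and the broadcast itself runs locally at
+    // the shard shape.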
+ auto output_offsets = + MakePartitionOffsets(hlo->shape(), hlo->sharding(), partition_id_, &b_); + std::vector input_offsets(operand.base_shape().rank()); + auto output_shard_shape = + MakePartitionedShape(hlo->shape(), hlo->sharding()); + auto input_shard_shape = input->shape(); + auto padded_input_shape = input->shape(); + for (int64 i = 0; i < input_offsets.size(); ++i) { + input_offsets[i] = output_offsets[hlo->dimensions(i)]; + input_shard_shape.set_dimensions( + i, output_shard_shape.dimensions(hlo->dimensions(i))); + padded_input_shape.set_dimensions( + i, hlo->sharding().tile_assignment().dim(hlo->dimensions(i)) * + input_shard_shape.dimensions(i)); + } + auto padded_input = PadToShape(input, padded_input_shape, &b_); + auto input_shard = + ShapeUtil::Compatible(input_shard_shape, padded_input->shape()) + ? padded_input + : b_.AddInstruction(HloInstruction::CreateDynamicSlice( + input_shard_shape, padded_input, input_offsets, + input_shard_shape.dimensions())); + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction( + hlo->CloneWithNewOperands(output_shard_shape, {input_shard})); + }); + } + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleConstant(HloInstruction* hlo) { + const Literal& literal = hlo->literal(); + if (literal.shape().IsTuple() || + (!hlo->sharding().IsTileMaximal() && + (!EvenlyPartitions(hlo->shape(), hlo->sharding()) || + !literal.IsAllFirst()))) { + return DefaultAction(hlo); + } + + SetPartitionedHlo(hlo, [&]() { + auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); + std::vector start_indices(hlo->shape().rank(), 0); + auto constant = b_.AddInstruction(HloInstruction::CreateConstant( + literal.Slice(start_indices, shard_shape.dimensions()))); + *constant->mutable_shape() = shard_shape; + return constant; + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleDynamicSlice(HloInstruction* hlo) { + if (hlo->sharding().IsTileMaximal()) { + return DefaultAction(hlo); + } + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + if (hlo->sharding().tile_assignment().dim(i) != 1 && + (hlo->dynamic_slice_sizes()[i] != hlo->shape().dimensions(i) || + !hlo->operand(i + 1)->IsConstant() || + !hlo->operand(i + 1)->literal().IsZero({}))) { + // We currently do not partition the sliced dimensions. + return DefaultAction(hlo); + } + } + std::vector new_indices(hlo->shape().rank()); + auto new_input = + GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding()).hlo(); + for (int64 i = 0; i < new_indices.size(); ++i) { + // Replicate the indices. + new_indices[i] = GetPartitionedHlo(hlo->operand(i + 1)) + .Reshard(HloSharding::Replicate()) + .hlo(); + } + SetPartitionedHlo(hlo, [&]() { + auto partitioned_shape = + MakePartitionedShape(hlo->shape(), hlo->sharding()); + return b_.AddInstruction(HloInstruction::CreateDynamicSlice( + partitioned_shape, new_input, new_indices, + partitioned_shape.dimensions())); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleDynamicUpdateSlice(HloInstruction* hlo) { + if (hlo->sharding().IsTileMaximal()) { + return DefaultAction(hlo); + } + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + if (hlo->sharding().tile_assignment().dim(i) != 1 && + (hlo->operand(1)->shape().dimensions(i) != hlo->shape().dimensions(i) || + !hlo->operand(i + 2)->IsConstant() || + !hlo->operand(i + 2)->literal().IsZero({}))) { + // We currently do not partition the sliced dimensions. 
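+      // That is, a dimension may be partitioned only if the update covers it
+      // entirely and the corresponding start index is a constant 0; anything
+      // else falls back to DefaultAction below.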
+ return DefaultAction(hlo); + } + } + std::vector new_indices(hlo->shape().rank()); + auto new_input = + GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding()).hlo(); + auto new_update = + GetPartitionedHlo(hlo->operand(1)).Reshard(hlo->sharding()).hlo(); + for (int64 i = 0; i < new_indices.size(); ++i) { + // Replicate the indices. + new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2)) + .Reshard(HloSharding::Replicate()) + .hlo(); + } + SetPartitionedHlo(hlo, [&]() { + auto partitioned_shape = + MakePartitionedShape(hlo->shape(), hlo->sharding()); + return b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + partitioned_shape, new_input, new_update, new_indices)); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) { + auto gather = Cast(hlo); + const auto& dnums = gather->gather_dimension_numbers(); + auto operand = GetPartitionedHlo(gather->operand(0)); + auto indices = GetPartitionedHlo(gather->operand(1)); + std::vector collapsed_slice_dims(dnums.collapsed_slice_dims().begin(), + dnums.collapsed_slice_dims().end()); + std::vector start_index_map(dnums.start_index_map().begin(), + dnums.start_index_map().end()); + std::vector offset_dims(dnums.offset_dims().begin(), + dnums.offset_dims().end()); + if (!operand.sharding().IsTileMaximal()) { + auto maybe_passthrough = PassthroughOperandToGatherOutputOrScatterUpdate( + operand, gather->shape(), collapsed_slice_dims, start_index_map, + offset_dims, gather->gather_slice_sizes()); + if (maybe_passthrough.has_value()) { + indices = indices.Reshard(HloSharding::Replicate()); + auto pshape = MakePartitionedShape(gather->shape(), *maybe_passthrough); + std::vector pslice_sizes(gather->gather_slice_sizes().begin(), + gather->gather_slice_sizes().end()); + for (int64 i = 0; i < pslice_sizes.size(); ++i) { + if (operand.sharding().tile_assignment().dim(i) > 1) { + pslice_sizes[i] = operand.hlo()->shape().dimensions(i); + } + } + auto pgather = b_.AddInstruction(HloInstruction::CreateGather( + pshape, operand.hlo(), indices.hlo(), dnums, pslice_sizes, + gather->indices_are_sorted())); + pgather->set_sharding(*maybe_passthrough); + SetPartitionedHlo(hlo, [&]() { + return PartitionedHlo(pgather, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); + } + if (GatherScatterOperandPartitionedOnlyOnTrivialSliceDims( + operand, start_index_map, gather->gather_slice_sizes(), + num_partitions_) && + ShapeSizeInBytes(gather->shape()) < + ShapeSizeInBytes(gather->operand(0)->shape())) { + indices = indices.Reshard(HloSharding::Replicate()); + // Now the operand is partitioned in trivial slice dimensions, and the + // indices are replicated. We execute a gather on partitioned operand, + // with full number of indices, where out-of-bounds indices are clamped, + // and masked out with 0 in the result; then we use all-reduce to combine + // results. Although gather will not get faster, we avoided the need to + // replicate the operand. + HloInstruction* indices_min; + HloInstruction* indices_max; + std::tie(indices_min, indices_max) = + IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims( + operand, indices, partition_id_, start_index_map, + dnums.index_vector_dim(), &b_); + // Clamp the indices. + auto adjusted_indices = b_.AddInstruction(HloInstruction::CreateTernary( + indices.base_shape(), HloOpcode::kClamp, indices_min, indices.hlo(), + indices_max)); + // Adjust the indices by subtracting the offset. 
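+      // For example, if the operand has 100 rows split evenly across 4
+      // partitions and this partition owns rows 50..74, indices_min is 50: an
+      // index of 60 becomes local index 10 after the subtraction below, while
+      // indices outside the partition's range are clamped just above, gathered
+      // anyway, and then zeroed out by the filter before the all-reduce
+      // combines the per-partition results.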
+ adjusted_indices = b_.AddInstruction(HloInstruction::CreateBinary( + indices.base_shape(), HloOpcode::kSubtract, adjusted_indices, + indices_min)); + // Gather on adjusted indices. + auto pgather = b_.AddInstruction(HloInstruction::CreateGather( + gather->shape(), operand.hlo(), adjusted_indices, dnums, + gather->gather_slice_sizes(), gather->indices_are_sorted())); + // Mask out invalid results. + auto filter = b_.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::ChangeElementType(indices.base_shape(), PRED), + indices.hlo(), indices_min, ComparisonDirection::kLt)); + filter = b_.AddInstruction(HloInstruction::CreateBinary( + filter->shape(), HloOpcode::kOr, filter, + b_.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::ChangeElementType(indices.base_shape(), PRED), + indices.hlo(), indices_max, ComparisonDirection::kGt)))); + if (dnums.index_vector_dim() < indices.base_shape().rank()) { + std::vector reduced_filter_dims; + for (int64 i = 0; i < filter->shape().rank(); ++i) { + if (i != dnums.index_vector_dim()) { + reduced_filter_dims.push_back(filter->shape().dimensions(i)); + } + } + filter = b_.AddInstruction(HloInstruction::CreateReduce( + ShapeUtil::MakeShape(PRED, reduced_filter_dims), filter, + CreateR0WithType(PRED, false, &b_), {dnums.index_vector_dim()}, + MakeBinaryAdd(PRED, module_))); + } + std::vector batch_dims; + for (int64 i = 0; i < pgather->shape().rank(); ++i) { + if (!absl::c_linear_search(dnums.offset_dims(), i)) { + batch_dims.push_back(i); + } + } + auto broadcast_filter = b_.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::ChangeElementType(pgather->shape(), PRED), filter, + batch_dims)); + auto filtered = b_.AddInstruction(HloInstruction::CreateTernary( + pgather->shape(), HloOpcode::kSelect, broadcast_filter, + CreateZero(pgather->shape(), &b_), pgather)); + // Combine from different partitions. + auto ar = collective_ops_creator_.create_cross_partition_all_reduce( + &b_, filtered, + MakeBinaryAdd(filtered->shape().element_type(), module_), + NewChannel()); + ar->set_sharding(HloSharding::Replicate()); + SetPartitionedHlo(hlo, [&]() { + return PartitionedHlo(ar, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); + } + } + return DefaultAction(hlo); +} + +Status SpmdPartitioningVisitor::HandleGetTupleElement(HloInstruction* hlo) { + const auto& tuple = GetPartitionedHlo(hlo->operand(0)); + auto gte = b_.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(tuple.hlo()->shape(), hlo->tuple_index()), + tuple.hlo(), hlo->tuple_index())); + SetPartitionedHlo(hlo, [&]() { + const auto source_sharding = tuple.sharding().GetSubSharding( + tuple.base_shape(), {hlo->tuple_index()}); + gte->set_sharding(source_sharding); + PartitionedHlo source_partitioned_gte(gte, hlo->shape(), + MakePartitioningState()); + return source_partitioned_gte.Reshard(hlo->sharding()).hlo(); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleInfeed(HloInstruction* hlo) { + const Shape& shape = ShapeUtil::GetTupleElementShape(hlo->shape(), 0); + auto token = GetPartitionedHlo(hlo->operand(0)).hlo(); + if (ShapeUtil::GetLeafCount(shape) == 0) { + // TODO(b/155819021): HloSharding has issues with tuple-shaped sharding: it + // requires one element for an empty tuple, but leaf-count number of + // elements for non-empty tuple. So if it has a nested empty tuple, we + // cannot invoke GetSubSharding() since it expects a sharding for the empty + // tuple. 
This is a workaround for that case. + SetPartitionedHlo(hlo, [&]() { + return b_.AddInstruction( + HloInstruction::CreateInfeed(shape, token, hlo->infeed_config())); + }); + return Status::OK(); + } + auto sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0}); + auto shard_shape = MakePartitionedShape(shape, sharding); + if (EvenlyPartitions(shape, sharding)) { + SetPartitionedHlo(hlo, [&]() { + return b_.AddInstruction(HloInstruction::CreateInfeed( + shard_shape, token, hlo->infeed_config())); + }); + return Status::OK(); + } + + if (hlo->sharding().HasUniqueDevice()) { + return HandleSingleDevice(hlo); + } + + // Create a branch for each unique partitioned shape. + std::vector per_branch_partitioned_shapes; + std::vector conditional_branch_indices(num_partitions_); + for (int64 i = 0; i < num_partitions_; ++i) { + auto partitioned_shape = + MakeNonPaddedShapeForGivenPartition(shape, sharding, i); + int64 matching_existing_index = 0; + for (; matching_existing_index < per_branch_partitioned_shapes.size(); + ++matching_existing_index) { + if (ShapeUtil::Compatible( + partitioned_shape, + per_branch_partitioned_shapes[matching_existing_index])) { + break; + } + } + if (matching_existing_index < per_branch_partitioned_shapes.size()) { + conditional_branch_indices[i] = matching_existing_index; + } else { + conditional_branch_indices[i] = per_branch_partitioned_shapes.size(); + per_branch_partitioned_shapes.push_back(std::move(partitioned_shape)); + } + } + + HloInstruction* branch_index; + if (per_branch_partitioned_shapes.size() == num_partitions_) { + // Use partition ID as the branch index if each partition has its own + // branch. + branch_index = partition_id_; + // PartitionId's output is U32 but conditional requires S32. + if (branch_index->shape().element_type() != S32) { + branch_index = b_.AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(branch_index->shape(), S32), + branch_index)); + } + } else { + // Otherwise, use a constant table to look up the branch index. + auto branch_index_table = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1(conditional_branch_indices))); + branch_index = b_.AddInstruction(HloInstruction::CreateDynamicSlice( + ShapeUtil::MakeShape(S32, {1}), branch_index_table, {partition_id_}, + {1})); + branch_index = b_.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(S32, {}), branch_index)); + } + + std::vector branches(per_branch_partitioned_shapes.size()); + for (int64 i = 0; i < branches.size(); ++i) { + SpmdBuilder branch_b(absl::StrCat("infeed_branch_", i), visiting_hlo_); + auto param = branch_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, token->shape(), "infeed_token_param")); + auto infeed = branch_b.AddInstruction(HloInstruction::CreateInfeed( + per_branch_partitioned_shapes[i], param, hlo->infeed_config())); + branches[i] = module_->AddEmbeddedComputation(branch_b.Build(infeed)); + if (!ShapeUtil::Compatible(per_branch_partitioned_shapes[i], shard_shape)) { + TF_ASSIGN_OR_RETURN( + auto padded, + branches[i]->DeepCopyInstructionWithCustomCopier( + infeed, [&](HloInstruction* leaf, const ShapeIndex& leaf_index, + HloComputation* comp) { + // Index {1} corresponds to the token. 
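+                // Only the data leaves under tuple index {0} may need padding
+                // up to the shard shape; the token leaf (index {1}) is
+                // returned unchanged.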
+ if (leaf_index.empty() || leaf_index[0] != 0) { + return leaf; + } + ShapeIndexView subindex(leaf_index, 1); + if (ShapeUtil::Compatible( + ShapeUtil::GetSubshape(per_branch_partitioned_shapes[i], + subindex), + ShapeUtil::GetSubshape(shard_shape, subindex))) { + return leaf; + } + return PadToShape(leaf, + ShapeUtil::GetSubshape(shard_shape, subindex), + nullptr, comp); + })); + branches[i]->set_root_instruction(padded, + /*accept_different_shape=*/true); + } + } + SetPartitionedHlo(hlo, [&]() { + return b_.AddInstruction(HloInstruction::CreateConditional( + ShapeUtil::MakeTupleShape({shard_shape, token->shape()}), branch_index, + branches, std::vector(branches.size(), token))); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandlePad(HloInstruction* hlo) { + if (hlo->sharding().IsTileMaximal()) { + return DefaultAction(hlo); + } + auto lhs = GetPartitionedHlo(hlo->operand(0)); + // Create a window config to represent the pad. + Window window; + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + const auto& pd = hlo->padding_config().dimensions(i); + WindowDimension* dim = window.add_dimensions(); + dim->set_size(1); + dim->set_stride(1); + dim->set_window_dilation(1); + dim->set_window_reversal(false); + dim->set_padding_low(pd.edge_padding_low()); + dim->set_padding_high(pd.edge_padding_high()); + dim->set_base_dilation(pd.interior_padding() + 1); + } + + auto replicated_rhs = GetPartitionedHlo(hlo->operand(1)) + .Reshard(HloSharding::Replicate()) + .hlo(); + auto reshard_operand = + lhs.ReshardAsWindowedInput(window, hlo->sharding(), replicated_rhs, + /*mask_invalid_region=*/false); + if (!reshard_operand.has_value()) { + return DefaultAction(hlo); + } + PaddingConfig sharded_padding_config; + bool need_pad = false; + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + auto dim = sharded_padding_config.add_dimensions(); + const auto& wd = reshard_operand->shard_window.dimensions(i); + dim->set_edge_padding_low(wd.padding_low()); + dim->set_edge_padding_high(wd.padding_high()); + dim->set_interior_padding(wd.base_dilation() - 1); + if (wd.padding_low() != 0 || wd.padding_high() != 0 || + wd.base_dilation() != 1) { + need_pad = true; + } + } + auto sharded_pad = reshard_operand->sharded_input; + if (need_pad) { + TF_ASSIGN_OR_RETURN(auto sharded_pad_shape, + ShapeInference::InferPadShape(sharded_pad->shape(), + replicated_rhs->shape(), + sharded_padding_config)); + sharded_pad = b_.AddInstruction(hlo->CreatePad(sharded_pad_shape, + sharded_pad, replicated_rhs, + sharded_padding_config)); + } + + SetPartitionedHlo(hlo, [&]() { + if (!reshard_operand->dynamic_slice_index_on_output) { + return sharded_pad; + } + auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); + return b_.AddInstruction(HloInstruction::CreateDynamicSlice( + shard_shape, sharded_pad, + *reshard_operand->dynamic_slice_index_on_output, + shard_shape.dimensions())); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleParameter(HloInstruction* hlo) { + SetPartitionedHlo(hlo, [&]() { + auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); + auto new_param = b_.AddInstruction(HloInstruction::CreateParameter( + hlo->parameter_number(), shard_shape, "param")); + if (hlo->parameter_replicated_at_leaf_buffers()) { + new_param->set_parameter_replicated_at_leaf_buffers( + *hlo->parameter_replicated_at_leaf_buffers()); + } + return new_param; + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleReduce(HloInstruction* hlo) { + int64 
input_count = 1; + auto per_input_sharding = hlo->sharding(); + if (hlo->shape().IsTuple()) { + input_count = hlo->shape().tuple_shapes_size(); + CHECK_GT(input_count, 0); + per_input_sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0}); + } + + std::vector inputs; + std::vector inits; + for (int64 operand_id = 0; operand_id < input_count; ++operand_id) { + inits.push_back(GetPartitionedHlo(hlo->operand(operand_id + input_count)) + .Reshard(HloSharding::Replicate()) + .hlo()); + inputs.push_back(GetPartitionedHlo(hlo->operand(operand_id))); + if (operand_id > 0) { + // Make sure all operands are sharded in the same way. + inputs.back() = inputs.back().Reshard(inputs[0].sharding()); + } + if (!inputs[0].sharding().IsTileMaximal()) { + inputs.back() = inputs.back().PadWithValue(inits[operand_id]); + } + } + bool reduce_sharded_dimension = false; + if (!inputs[0].sharding().IsTileMaximal()) { + reduce_sharded_dimension = absl::c_any_of(hlo->dimensions(), [&](int64 i) { + return inputs[0].sharding().tile_assignment().dim(i) > 1; + }); + + // reduce_sharded_dimension is not supported for tuple-shaped reduces. + if (reduce_sharded_dimension && input_count > 1) { + return DefaultAction(hlo); + } + + // Currently we only support reducing all or none of the sharded + // dimensions. + if (reduce_sharded_dimension) { + for (int64 i = 0; i < inputs[0].base_shape().rank(); ++i) { + if (inputs[0].sharding().tile_assignment().dim(i) > 1 && + absl::c_count(hlo->dimensions(), i) == 0) { + return DefaultAction(hlo); + } + } + } + } + + std::vector new_operand_shapes(input_count * 2); + for (int64 i = 0; i < input_count; ++i) { + new_operand_shapes[i] = inputs[i].hlo()->mutable_shape(); + new_operand_shapes[i + input_count] = inits[i]->mutable_shape(); + } + // Create the shard shape of the reduce result. + TF_ASSIGN_OR_RETURN( + auto reduce_shape, + ShapeInference::InferReduceShape(new_operand_shapes, hlo->dimensions(), + hlo->to_apply()->ComputeProgramShape())); + *reduce_shape.mutable_layout() = hlo->shape().layout(); + + std::vector input_hlos(input_count); + for (int64 i = 0; i < input_count; ++i) { + input_hlos[i] = inputs[i].hlo(); + } + auto local_reduce = b_.AddInstruction(HloInstruction::CreateReduce( + reduce_shape, input_hlos, inits, hlo->dimensions(), hlo->to_apply())); + local_reduce->set_metadata(hlo->metadata()); + + SetPartitionedHlo(hlo, [&]() { + HloInstruction* reduce; + if (reduce_sharded_dimension) { + CHECK(local_reduce->shape().IsArray()); + reduce = collective_ops_creator_.create_cross_partition_all_reduce( + &b_, local_reduce, hlo->to_apply(), NewChannel()); + reduce->set_sharding(HloSharding::Replicate()); + } else { + reduce = local_reduce; + if (inputs[0].sharding().IsTileMaximal()) { + reduce->set_sharding(inputs[0].sharding()); + } else { + // Remove tile assignment dimensions that are reduced. 
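+        // For example, an input tile assignment of shape [2,1] (dim 0 split
+        // two ways, dim 1 replicated) with a reduce over dimension 1 becomes
+        // a tile assignment of shape [2] for the output; the local reduce
+        // above already produced correctly sharded data, so only the sharding
+        // annotation changes.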
+ std::vector tile_dimensions; + for (int64 i = 0; i < input_hlos[0]->shape().rank(); ++i) { + if (absl::c_count(hlo->dimensions(), i) == 0) { + tile_dimensions.push_back( + inputs[0].sharding().tile_assignment().dim(i)); + } + } + Array new_tile = inputs[0].sharding().tile_assignment(); + new_tile.Reshape(tile_dimensions); + auto sharding = HloSharding::Tile(new_tile); + if (input_count > 1) { + std::vector tuple(input_count, sharding); + sharding = HloSharding::Tuple(hlo->shape(), tuple); + } + reduce->set_sharding(sharding); + } + } + + return PartitionedHlo(reduce, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleReverse(HloInstruction* hlo) { + auto reverse = Cast(hlo); + if (reverse->sharding().IsTileMaximal()) { + return DefaultAction(hlo); + } + auto operand = GetPartitionedHlo(reverse->operand(0)) + .Reshard(hlo_sharding_util::ReverseSharding( + reverse->sharding(), reverse->dimensions())); + auto left_padded_operand = + HaloExchangeToPadOnLeft(operand, reverse->dimensions()); + if (!left_padded_operand) { + return DefaultAction(hlo); + } + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction(hlo->CloneWithNewOperands( + left_padded_operand->shape(), {left_padded_operand})); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleWhile(HloInstruction* hlo) { + const HloSharding& sharding = hlo->sharding(); + + // Shardings for the body parameter, body root, and cond parameter must be + // the same, and the condition root must be replicated so that all partitions + // follow the same control flow. + hlo->while_condition()->parameter_instruction(0)->set_sharding(sharding); + hlo->while_body()->parameter_instruction(0)->set_sharding(sharding); + TF_RETURN_IF_ERROR(partitioner_ + ->PartitionComputation(hlo->while_condition(), + HloSharding::Replicate(), + next_channel_id_, logger_) + .status()); + TF_RETURN_IF_ERROR(partitioner_ + ->PartitionComputation(hlo->while_body(), sharding, + next_channel_id_, logger_) + .status()); + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction(HloInstruction::CreateWhile( + MakePartitionedShape(hlo->shape(), sharding), hlo->while_condition(), + hlo->while_body(), + GetPartitionedHlo(hlo->operand(0)).Reshard(sharding).hlo())); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleConditional(HloInstruction* hlo) { + std::vector branch_args; + for (int64 i = 0; i < hlo->branch_count(); ++i) { + HloComputation* computation = hlo->branch_computation(i); + + // Shardings of the branch computation parameter and its argument must be + // the same. + computation->parameter_instruction(0)->set_sharding( + hlo->operand(i + 1)->sharding()); + branch_args.push_back(GetPartitionedHlo(hlo->operand(i + 1)).hlo()); + } + + // The root of the branch computations must follow the sharding of the + // conditional instruction. + for (int64 i = 0; i < hlo->branch_count(); ++i) { + HloComputation* computation = hlo->branch_computation(i); + TF_RETURN_IF_ERROR(partitioner_ + ->PartitionComputation(computation, hlo->sharding(), + next_channel_id_, logger_) + .status()); + } + + // We replicate the predicate of the conditional (the first operand) so that + // all partitions follow the same control flow. 
+ SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction(HloInstruction::CreateConditional( + MakePartitionedShape(hlo->shape(), hlo->sharding()), + GetPartitionedHlo(hlo->operand(0)) + .Reshard(HloSharding::Replicate()) + .hlo(), + hlo->called_computations(), branch_args)); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleOutfeed(HloInstruction* hlo) { + TF_RET_CHECK(hlo->sharding().HasUniqueDevice()); + return HandleSingleDevice(hlo); +} + +Status SpmdPartitioningVisitor::HandleRng(HloInstruction* hlo) { + if (hlo->sharding().HasUniqueDevice()) { + return HandleSingleDevice(hlo); + } + + if (hlo->sharding().IsReplicated()) { + SetPartitionedHlo(hlo, [&] { + // Run on a single device (0) and distribute the data to all other cores. + std::vector new_operands; + for (int64 i = 0; i < hlo->operand_count(); ++i) { + new_operands.push_back(GetPartitionedHlo(hlo->operand(i)) + .Reshard(HloSharding::AssignDevice(0)) + .hlo()); + } + auto clone = b_.AddInstruction( + hlo->CloneWithNewOperands(hlo->shape(), new_operands)); + clone->set_sharding(HloSharding::AssignDevice(0)); + return PartitionedHlo(clone, hlo->shape(), MakePartitioningState()) + .Reshard(HloSharding::Replicate()) + .hlo(); + }); + return Status::OK(); + } + + TF_RET_CHECK(!hlo->sharding().IsTileMaximal()); + SetPartitionedHlo(hlo, [&] { + // Replicate the operands and run partitioned Rng on all devices. + std::vector new_operands; + for (int64 i = 0; i < hlo->operand_count(); ++i) { + new_operands.push_back(GetPartitionedHlo(hlo->operand(i)) + .Reshard(HloSharding::Replicate()) + .hlo()); + } + return b_.AddInstruction(HloInstruction::CreateRng( + MakePartitionedShape(hlo->shape(), hlo->sharding()), + hlo->random_distribution(), new_operands)); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleReduceWindow(HloInstruction* hlo) { + auto& operand = GetPartitionedHlo(hlo->operand(0)); + if (hlo->sharding().IsTileMaximal()) { + return DefaultAction(hlo); + } + + // Replicate init + auto replicated_init = GetPartitionedHlo(hlo->mutable_operand(1)) + .Reshard(HloSharding::Replicate()); + auto resharded_operand_and_window = operand.ReshardAsWindowedInput( + hlo->window(), hlo->sharding(), replicated_init.hlo()); + if (!resharded_operand_and_window.has_value()) { + return DefaultAction(hlo); + } + + TF_ASSIGN_OR_RETURN(Shape sharded_rw_shape, + ShapeInference::InferReduceWindowShape( + resharded_operand_and_window->sharded_input->shape(), + replicated_init.hlo()->shape(), + resharded_operand_and_window->shard_window, + hlo->to_apply()->ComputeProgramShape())); + auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); + *sharded_rw_shape.mutable_layout() = shard_shape.layout(); + SetPartitionedHlo(hlo, [&]() { + auto sharded_rw = b_.AddInstruction(HloInstruction::CreateReduceWindow( + sharded_rw_shape, resharded_operand_and_window->sharded_input, + replicated_init.hlo(), resharded_operand_and_window->shard_window, + hlo->to_apply())); + if (!resharded_operand_and_window->dynamic_slice_index_on_output + .has_value()) { + CHECK(ShapeUtil::Compatible(shard_shape, sharded_rw->shape())); + return sharded_rw; + } + return b_.AddInstruction(HloInstruction::CreateDynamicSlice( + shard_shape, sharded_rw, + *resharded_operand_and_window->dynamic_slice_index_on_output, + shard_shape.dimensions())); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleSelectAndScatter(HloInstruction* hlo) { + if (hlo->sharding().IsTileMaximal()) { + return DefaultAction(hlo); + } + 
auto operand = GetPartitionedHlo(hlo->operand(0));
+  auto source = GetPartitionedHlo(hlo->mutable_operand(1));
+  if (hlo->sharding() != operand.sharding()) {
+    operand = operand.Reshard(hlo->sharding());
+  }
+  if (hlo->sharding() != source.sharding()) {
+    source = source.Reshard(hlo->sharding());
+  }
+
+  // For F32 and BF16 types, we can use NaN padding to work around the issue
+  // with low/high padding, since comparison will return false with NaN input.
+  if (hlo->shape().element_type() != F32 &&
+      hlo->shape().element_type() != BF16) {
+    return DefaultAction(hlo);
+  }
+
+  auto select = hlo->called_computations()[0];
+  auto select_root = select->root_instruction();
+  if (select_root->opcode() != HloOpcode::kCompare ||
+      select_root->operand(0)->opcode() != HloOpcode::kParameter ||
+      select_root->operand(1)->opcode() != HloOpcode::kParameter ||
+      select_root->operand(0)->parameter_number() +
+              select_root->operand(1)->parameter_number() !=
+          1) {
+    return DefaultAction(hlo);
+  }
+
+  float float_pad_value;
+  if (select_root->comparison_direction() == ComparisonDirection::kGe ||
+      select_root->comparison_direction() == ComparisonDirection::kGt) {
+    if (select_root->operand(0)->parameter_number() == 0) {
+      float_pad_value = -std::numeric_limits<float>::infinity();
+    } else {
+      float_pad_value = std::numeric_limits<float>::infinity();
+    }
+  } else if (select_root->comparison_direction() == ComparisonDirection::kLe ||
+             select_root->comparison_direction() == ComparisonDirection::kLt) {
+    if (select_root->operand(0)->parameter_number() == 0) {
+      float_pad_value = std::numeric_limits<float>::infinity();
+    } else {
+      float_pad_value = -std::numeric_limits<float>::infinity();
+    }
+  } else {
+    return DefaultAction(hlo);
+  }
+
+  auto pad_value = b_.AddInstruction(HloInstruction::CreateConstant(
+      hlo->shape().element_type() == BF16
+          ? LiteralUtil::CreateR0<bfloat16>(
+                static_cast<bfloat16>(float_pad_value))
+          : LiteralUtil::CreateR0<float>(float_pad_value)));
+
+  // Replicate init
+  auto replicated_init = GetPartitionedHlo(hlo->mutable_operand(2))
+                             .Reshard(HloSharding::Replicate());
+
+  auto partition_ordinals =
+      MakeTiledPartitionOrdinals(hlo->sharding(), partition_id_, &b_);
+
+  // The first window for each dimension that overlaps with the shard area.
+  std::vector<MultiplyAddDivideOffsetCalculation> first_window(
+      hlo->shape().rank());
+  // The first window for each dimension that goes beyond the shard area.
+  std::vector<MultiplyAddDivideOffsetCalculation> limit_window(
+      hlo->shape().rank());
+  std::vector<OffsetCalculation> data_left_halo_sizes(hlo->shape().rank());
+  std::vector<OffsetCalculation> data_right_halo_sizes(hlo->shape().rank());
+  std::vector<OffsetCalculation> source_left_halo_sizes(hlo->shape().rank());
+  std::vector<OffsetCalculation> source_right_halo_sizes(hlo->shape().rank());
+  auto unpadded_data_shard_shape =
+      MakePartitionedShape(hlo->shape(), hlo->sharding());
+  auto unpadded_source_shard_shape =
+      MakePartitionedShape(hlo->operand(1)->shape(), hlo->sharding());
+  auto source_shard_hlo = source.hlo();
+  auto data_shard_hlo = operand.hlo();
+  for (int64 i = 0; i < hlo->shape().rank(); ++i) {
+    int64 shard_count = hlo->sharding().tile_assignment().dim(i);
+    if (shard_count == 1) {
+      continue;
+    }
+    // If stride > window_size, there will be gaps between windows. These gaps
+    // will also exist in the output, so we keep them during halo exchange.
+    //
+    // TODO(yuanzx): This could introduce overhead if partitions start at
+    // different offsets in a gap.
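+    // For example, with window size 2 and stride 3, positions 2, 5, 8, ... are
+    // not covered by any window; widening the local copy of the window to the
+    // stride below keeps those gap elements attached to their shard during the
+    // halo exchange, so the sharded output keeps the same layout as the
+    // unpartitioned result.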
+ auto wd = hlo->window().dimensions(i); + if (wd.stride() > wd.size()) { + wd.set_size(wd.stride()); + } + // shard_size * i < stride * k - pad_low + window_size => + // k > (shard_size * i + pad_low - window_size) / stride => + // first_k == (shard_size * i + pad_low - window_size + stride) / stride + first_window[i] = MultiplyAddDivideOffsetCalculation( + unpadded_data_shard_shape.dimensions(i), + wd.padding_low() - wd.size() + wd.stride(), wd.stride()); + // shard_size * (i + 1) <= stride * k - pad_low => + // k >= (shard_size * i + shard_size + pad_low) / stride => + // limit_k == (shard_size * i + shard_size + pad_low + stride - 1) / + // stride + limit_window[i] = MultiplyAddDivideOffsetCalculation( + unpadded_data_shard_shape.dimensions(i), + unpadded_data_shard_shape.dimensions(i) + wd.padding_low() + + wd.stride() - 1, + wd.stride()); + source_left_halo_sizes[i] = + MultiplyAddDivideOffsetCalculation( + unpadded_source_shard_shape.dimensions(i), 0, 1) - + first_window[i]; + source_right_halo_sizes[i] = + limit_window[i] - MultiplyAddDivideOffsetCalculation( + unpadded_source_shard_shape.dimensions(i), + unpadded_source_shard_shape.dimensions(i), 1); + data_left_halo_sizes[i] = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + unpadded_data_shard_shape.dimensions(i), wd.padding_low(), 1)) - + OffsetCalculation( + HloOpcode::kMultiply, first_window[i], + MultiplyAddDivideOffsetCalculation(0, wd.stride(), 1)); + data_right_halo_sizes[i] = + OffsetCalculation( + HloOpcode::kMultiply, limit_window[i], + MultiplyAddDivideOffsetCalculation(0, wd.stride(), 1)) - + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + unpadded_data_shard_shape.dimensions(i), + unpadded_data_shard_shape.dimensions(i) + wd.stride() + + wd.padding_low() - wd.size(), + 1)); + + int64 max_windows = + (limit_window[i] - first_window[i]).MaxInRange(0, shard_count); + auto first_window_hlo = + first_window[i].Calculate(partition_ordinals[i], &b_); + // Padding on the source is filled with the init value so they do not change + // the data on overlapping windows. 
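+    // As a concrete example of the window bookkeeping above: with shard size
+    // 8, window size 3, stride 2 and low padding 1, first_window(i) =
+    // (8 * i + 1 - 3 + 2) / 2 = 4 * i and limit_window(i) =
+    // (8 * i + 8 + 1 + 2 - 1) / 2 = 4 * i + 5, so the windows owned by
+    // neighboring shards overlap by one window, which is what the source and
+    // data halo exchanges below account for.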
+ auto resharded_source = ExchangeHaloAndGetValidData( + source_shard_hlo, source.base_shape(), source_left_halo_sizes[i], + source_right_halo_sizes[i], 0, + limit_window[i].Calculate(shard_count - 1), max_windows, i, + hlo->sharding(), first_window_hlo, replicated_init.hlo(), + partition_ordinals[i], collective_ops_creator_, next_channel_id_, &b_); + if (!resharded_source) { + return DefaultAction(hlo); + } + source_shard_hlo = *resharded_source; + + auto offset_start_in_data = + MultiplyAddDivideOffsetCalculation(wd.stride(), 0, 1) + .Calculate(first_window_hlo, &b_); + int64 padded_data_size = + (limit_window[i].Calculate(shard_count - 1) - 1) * wd.stride() + + wd.size(); + int64 data_shard_size = (max_windows - 1) * wd.stride() + wd.size(); + auto resharded_data = ExchangeHaloAndGetValidData( + data_shard_hlo, operand.base_shape(), data_left_halo_sizes[i], + data_right_halo_sizes[i], wd.padding_low(), padded_data_size, + data_shard_size, i, hlo->sharding(), offset_start_in_data, pad_value, + partition_ordinals[i], collective_ops_creator_, next_channel_id_, &b_); + if (!resharded_data) { + return DefaultAction(hlo); + } + data_shard_hlo = *resharded_data; + } + + Window window_on_shard = hlo->window(); + for (int64 i = 0; i < window_on_shard.dimensions_size(); ++i) { + int64 shard_count = hlo->sharding().tile_assignment().dim(i); + if (shard_count == 1) { + continue; + } + auto reshard_wd = window_on_shard.mutable_dimensions(i); + // The shards are already explicitly padded. + reshard_wd->set_padding_low(0); + reshard_wd->set_padding_high(0); + } + + auto sharded_select_and_scatter = + b_.AddInstruction(HloInstruction::CreateSelectAndScatter( + data_shard_hlo->shape(), data_shard_hlo, select, window_on_shard, + source_shard_hlo, replicated_init.hlo(), + hlo->called_computations()[1])); + SetPartitionedHlo(hlo, [&]() { + auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); + if (ShapeUtil::Compatible(sharded_select_and_scatter->shape(), + shard_shape)) { + return sharded_select_and_scatter; + } + auto zero = b_.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(S32))); + std::vector slice_offsets(shard_shape.rank(), zero); + for (int64 i = 0; i < window_on_shard.dimensions_size(); ++i) { + if (hlo->sharding().tile_assignment().dim(i) == 1) { + continue; + } + int64 pad_low = hlo->window().dimensions(i).padding_low(); + auto left_halo_size = + data_left_halo_sizes[i].Calculate(partition_ordinals[i], &b_); + if (data_left_halo_sizes[i].Calculate(0) == pad_low) { + slice_offsets[i] = left_halo_size; + } else { + auto is_shard0 = b_.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), zero, partition_ordinals[i], + ComparisonDirection::kEq)); + auto pad_low_hlo = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(pad_low))); + slice_offsets[i] = b_.AddInstruction(HloInstruction::CreateTernary( + zero->shape(), HloOpcode::kSelect, is_shard0, pad_low_hlo, + left_halo_size)); + } + } + return b_.AddInstruction(HloInstruction::CreateDynamicSlice( + shard_shape, sharded_select_and_scatter, slice_offsets, + shard_shape.dimensions())); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleTuple(HloInstruction* hlo) { + std::vector new_operands; + for (int64 i = 0; i < hlo->operand_count(); ++i) { + new_operands.push_back( + GetPartitionedHlo(hlo->operand(i)) + .Reshard(hlo->sharding().GetSubSharding(hlo->shape(), {i})) + .hlo()); + } + SetPartitionedHlo(hlo, [&]() { + return 
b_.AddInstruction(HloInstruction::CreateTuple(new_operands)); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleConvolutionTiledLhsAndRhs( + HloInstruction* hlo) { + TF_RET_CHECK(hlo->opcode() == HloOpcode::kConvolution); + + auto lhs = GetPartitionedHlo(hlo->operand(0)); + auto rhs = GetPartitionedHlo(hlo->operand(1)); + TF_RET_CHECK(!lhs.sharding().IsTileMaximal() && + !rhs.sharding().IsTileMaximal()); + + const auto& dnums = hlo->convolution_dimension_numbers(); + + // Check if the operand shardings are aligned. Also we currently don't + // support partitioning non-spatial dimensions. + std::vector rhs_to_lhs_indices(hlo->shape().rank()); + rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] = + dnums.input_batch_dimension(); + rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] = + dnums.input_feature_dimension(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] = + dnums.input_spatial_dimensions(i); + } + std::vector lhs_to_rhs_indices(hlo->shape().rank()); + for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) { + lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i; + } + + Window window = hlo->window(); + std::vector reversed_rhs_dims; + for (int64 i = 0; i < window.dimensions_size(); ++i) { + if (window.dimensions(i).window_reversal()) { + reversed_rhs_dims.push_back(dnums.kernel_spatial_dimensions(i)); + } + } + if (!reversed_rhs_dims.empty()) { + // Make the reversed dims left-padded to prepare for window reversal. + auto left_padded_rhs = HaloExchangeToPadOnLeft(rhs, reversed_rhs_dims); + if (left_padded_rhs == nullptr) { + return DefaultAction(hlo); + } + left_padded_rhs->set_sharding(rhs.sharding()); + rhs = PartitionedHlo(left_padded_rhs, rhs.base_shape(), rhs.state()); + } + // Consider window reversal when resharding RHS or LHS. Note: this will not + // reverse the data in the shard. We use window reversal to do that. + auto aligned_rhs_sharding = hlo_sharding_util::ReverseSharding( + hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices), + reversed_rhs_dims); + auto aligned_lhs_sharding = hlo_sharding_util::TransposeSharding( + hlo_sharding_util::ReverseSharding(rhs.sharding(), reversed_rhs_dims), + lhs_to_rhs_indices); + + auto unsupported_sharding = [&](const HloSharding& lhs_sharding, + const HloSharding& rhs_sharding) { + return lhs_sharding.tile_assignment().dim(dnums.input_batch_dimension()) != + 1 || + rhs_sharding.tile_assignment().dim( + dnums.kernel_output_feature_dimension()) != 1; + }; + + auto zero = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(hlo->shape().element_type()))); + if (ShapeSizeInBytes(lhs.base_shape()) < ShapeSizeInBytes(rhs.base_shape())) { + if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) { + return DefaultAction(hlo); + } + lhs = lhs.Reshard(aligned_lhs_sharding).PadWithValue(zero); + rhs = rhs.PadWithValue(zero, reversed_rhs_dims); + } else { + if (unsupported_sharding(lhs.sharding(), aligned_rhs_sharding)) { + return DefaultAction(hlo); + } + lhs = lhs.PadWithValue(zero); + rhs = + rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero, reversed_rhs_dims); + } + + // Reshard LHS by exchanging halo such that each shard computes the partial + // sum of the full shape result, and add AllReduce. + // + // The size of halo on each dimension can be calculated from the projection + // onto the LHS that each RHS shard i needs to read. 
RHS and LHS below refers + // to the shard size of RHS and LHS, WC is the number of windows, and D is the + // window dilation. + // + // * offset(i): RHS * D * i - low_padding + // * limit(i): {(RHS - 1) * D + 1} * (i + 1) + (WC - 1) * stride - low_padding + // + // Since shard i has LHS of range [i * LHS, (i + 1) * LHS) + // * left-halo: i * LHS - offset(i) + // = (LHS - RHS) * i + low_padding + // * right-halo: limit(i) - (i + 1) * LHS + // = [{(RHS - 1) * D + 1} - LHS] * (i + 1) + (WC - 1) * stride - low_padding + std::vector shard_counts(dnums.input_spatial_dimensions_size()); + std::vector lhs_shard_sizes(dnums.input_spatial_dimensions_size()); + std::vector rhs_shard_sizes(dnums.input_spatial_dimensions_size()); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + int64 lhs_dimension = dnums.input_spatial_dimensions(i); + int64 rhs_dimension = dnums.kernel_spatial_dimensions(i); + int64 shard_count = lhs.sharding().tile_assignment().dim(lhs_dimension); + auto wd = window.dimensions(i); + if (wd.base_dilation() != 1) { + return DefaultAction(hlo); + } + + int64 lhs_shard_size = + CeilOfRatio(lhs.base_shape().dimensions(lhs_dimension), shard_count); + int64 rhs_shard_size = + CeilOfRatio(rhs.base_shape().dimensions(rhs_dimension), shard_count); + shard_counts[i] = shard_count; + lhs_shard_sizes[i] = lhs_shard_size; + rhs_shard_sizes[i] = rhs_shard_size; + } + + std::vector left_halo_size_functions(hlo->shape().rank()); + std::vector right_halo_size_functions(hlo->shape().rank()); + Window new_window = window; + + auto partition_ordinals = + MakeTiledPartitionOrdinals(lhs.sharding(), partition_id_, &b_); + HloInstruction* lhs_with_halo = lhs.hlo(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + int64 lhs_dimension = dnums.input_spatial_dimensions(i); + int64 lhs_shard_size = lhs_shard_sizes[i]; + int64 rhs_shard_size = rhs_shard_sizes[i]; + + if (shard_counts[i] == 1) { + continue; + } + + // Calculate the left and right halo sizes as described in the comments + // above. + auto wd = window.dimensions(i); + int64 padding_low = wd.padding_low(); + int64 padding_high = wd.padding_high(); + int64 base = lhs.base_shape().dimensions(lhs_dimension); + int64 window_count = 1 + (padding_low + padding_high + base - + (1 + (wd.size() - 1) * wd.window_dilation())) / + wd.stride(); + int64 rhs_shard_size_dilated = + (rhs_shard_size - 1) * wd.window_dilation() + 1; + + left_halo_size_functions[lhs_dimension] = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + lhs_shard_size - rhs_shard_size * wd.window_dilation(), padding_low, + 1)); + right_halo_size_functions[lhs_dimension] = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + rhs_shard_size_dilated - lhs_shard_size, + rhs_shard_size_dilated - lhs_shard_size + + wd.stride() * (window_count - 1) - padding_low, + 1)); + + // Exchange halo and concatenate. + int64 dim = dnums.input_spatial_dimensions(i); + int64 explicit_left_padding_on_full_shape = padding_low; + int64 shard_size_with_halo = + wd.stride() * (window_count - 1) + rhs_shard_size_dilated; + + new_window.mutable_dimensions(i)->set_padding_low(0); + new_window.mutable_dimensions(i)->set_padding_high(0); + new_window.mutable_dimensions(i)->set_size(rhs_shard_size); + + // offset_on_padded_shape and padded_full_shape_size are needed only if + // we want to mask out-of-range values in ExchangeHaloAndGetValidData(). 
+ // Since the default value for both the collective-permute is zero and + // also we call PadWithValue() on both operands at the beginning, we + // don't need to mask here. + // + // TODO(hyoulkee): Consider removing one of the two PadWithValue() calls + // if it's always safe. + auto offset_on_padded_shape = + OffsetCalculation(MultiplyAddDivideOffsetCalculation()); + int64 padded_full_shape_size = 0; + auto concat = ExchangeHaloAndGetValidData( + lhs_with_halo, lhs.base_shape(), left_halo_size_functions[dim], + right_halo_size_functions[dim], explicit_left_padding_on_full_shape, + padded_full_shape_size, shard_size_with_halo, dim, lhs.sharding(), + offset_on_padded_shape.Calculate(partition_ordinals[dim], &b_), zero, + partition_ordinals[dim], collective_ops_creator_, next_channel_id_, &b_, + /*mask_invalid_region=*/false); + if (!concat) { + return DefaultAction(hlo); + } + lhs_with_halo = *concat; + } + + SetPartitionedHlo(hlo, [&]() { + auto conv = b_.AddInstruction(HloInstruction::CreateConvolve( + hlo->shape(), lhs_with_halo, rhs.hlo(), hlo->feature_group_count(), + hlo->batch_group_count(), new_window, + hlo->convolution_dimension_numbers(), hlo->precision_config())); + auto ar = collective_ops_creator_.create_cross_partition_all_reduce( + &b_, conv, MakeBinaryAdd(hlo->shape().element_type(), module_), + NewChannel()); + ar->set_sharding(HloSharding::Replicate()); + return PartitionedHlo(ar, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) { + auto dot_dnums = dot_as_convolution_util::ParseDotGeneralFromConvolution(hlo); + if (dot_dnums) { + // Use HandleDotHelper() for convs that are actually einsums. + spmd::DotGeneralDimsMapping mapping; + for (const auto& dims : dot_dnums->batch_dims) { + mapping.batch_dims.emplace_back(); + mapping.batch_dims.back().lhs = dims.lhs; + mapping.batch_dims.back().rhs = dims.rhs; + mapping.batch_dims.back().output = dims.output; + } + for (const auto& dims : dot_dnums->contracting_dims) { + mapping.contracting_dims.emplace_back(); + mapping.contracting_dims.back().lhs = dims.lhs; + mapping.contracting_dims.back().rhs = dims.rhs; + mapping.contracting_dims.back().output = dims.output; + } + for (const auto& dims : dot_dnums->lhs_non_contracting_dims) { + mapping.lhs_non_contracting_dims.emplace_back(); + mapping.lhs_non_contracting_dims.back().lhs = dims.lhs; + mapping.lhs_non_contracting_dims.back().rhs = dims.rhs; + mapping.lhs_non_contracting_dims.back().output = dims.output; + } + for (const auto& dims : dot_dnums->rhs_non_contracting_dims) { + mapping.rhs_non_contracting_dims.emplace_back(); + mapping.rhs_non_contracting_dims.back().lhs = dims.lhs; + mapping.rhs_non_contracting_dims.back().rhs = dims.rhs; + mapping.rhs_non_contracting_dims.back().output = dims.output; + } + auto create_sharded_conv = + [&](HloInstruction* lhs_hlo, HloInstruction* rhs_hlo, + spmd::SpmdBuilder* b) -> StatusOr { + TF_ASSIGN_OR_RETURN( + auto sharded_conv, + dot_as_convolution_util::CreateShardedConvForDotGeneralConvolution( + *hlo, *dot_dnums, lhs_hlo, rhs_hlo)); + return b->AddInstruction(std::move(sharded_conv)); + }; + return HandleDotHelper(hlo, mapping, create_sharded_conv); + } + + auto lhs = GetPartitionedHlo(hlo->operand(0)); + auto rhs = GetPartitionedHlo(hlo->operand(1)); + const HloSharding& sharding = hlo->sharding(); + const auto& dnums = hlo->convolution_dimension_numbers(); + std::vector 
rhs_to_lhs_indices(hlo->shape().rank()); + rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] = + dnums.input_batch_dimension(); + rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] = + dnums.input_feature_dimension(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] = + dnums.input_spatial_dimensions(i); + } + std::vector lhs_to_rhs_indices(hlo->shape().rank()); + for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) { + lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i; + } + auto aligned_rhs_sharding = + hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices); + auto aligned_lhs_sharding = + hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices); + + // Handling cases where both operands' shardings are aligned. We check that + // the LHS batch dimension is not partitioned because it is mapped to the + // output feature dimension in aligned_rhs_sharding, which are not the same + // dimension. + if (!lhs.sharding().IsTileMaximal() && !rhs.sharding().IsTileMaximal()) { + if (options_.conv_halo_exchange_always_on_lhs) { + return HandleConvolutionTiledLhsAndRhs(hlo); + } else { + // Reshard RHS so that each shard computes the partial sum of the full + // shape result, and add AllReduce. See HandleConvolutionTiledLhsAndRhs() + // that reshards LHS. + // + // The size of halo on each dimension can be calculated from the + // projection onto the RHS that shard i needs to read. RHS and LHS below + // refers to the shard size of RHS and LHS, WC is the number of windows, + // and D is the window dilation. + // + // * offset(i): LHS * i + low_padding - (WC - 1) * stride + // * limit(i): LHS * (i + 1) + low_padding + // + // Since shard i has RHS of range [i * RHS * D, (i + 1) * RHS * D) + // * left-halo: i * RHS - offset(i) + // = i * (RHS * D - LHS) + (WC - 1) * stride - low_padding + // * right-halo: limit(i) - (i + 1) * RHS + // = (i + 1) * (LHS - RHS * D) + low_pading + + auto unsupported_sharding = [&](const HloSharding& lhs_sharding, + const HloSharding& rhs_sharding) { + // We currently don't support partitioning input batch or output feature + // dimensions. 
+ return lhs_sharding.tile_assignment().dim( + dnums.input_batch_dimension()) != 1 || + rhs_sharding.tile_assignment().dim( + dnums.kernel_output_feature_dimension()) != 1; + }; + auto zero = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(hlo->shape().element_type()))); + if (ShapeSizeInBytes(lhs.base_shape()) < + ShapeSizeInBytes(rhs.base_shape())) { + if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) { + return DefaultAction(hlo); + } + lhs = lhs.Reshard(aligned_lhs_sharding).PadWithValue(zero); + rhs = rhs.PadWithValue(zero); + } else { + if (unsupported_sharding(lhs.sharding(), aligned_rhs_sharding)) { + return DefaultAction(hlo); + } + lhs = lhs.PadWithValue(zero); + rhs = rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero); + } + + Window window = hlo->window(); + std::vector shard_counts(dnums.input_spatial_dimensions_size()); + std::vector lhs_shard_sizes(dnums.input_spatial_dimensions_size()); + std::vector rhs_shard_sizes(dnums.input_spatial_dimensions_size()); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + int64 lhs_dimension = dnums.input_spatial_dimensions(i); + int64 rhs_dimension = dnums.kernel_spatial_dimensions(i); + int64 shard_count = rhs.sharding().tile_assignment().dim(rhs_dimension); + auto wd = window.dimensions(i); + if (wd.base_dilation() != 1 || wd.window_reversal()) { + return DefaultAction(hlo); + } + + int64 lhs_shard_size = CeilOfRatio( + lhs.base_shape().dimensions(lhs_dimension), shard_count); + int64 rhs_shard_size = CeilOfRatio( + rhs.base_shape().dimensions(rhs_dimension), shard_count); + shard_counts[i] = shard_count; + lhs_shard_sizes[i] = lhs_shard_size; + rhs_shard_sizes[i] = rhs_shard_size; + } + + std::vector left_halo_size_functions( + hlo->shape().rank()); + std::vector right_halo_size_functions( + hlo->shape().rank()); + Window new_window = window; + + // Data structures needed for Pad and DynamicSlice on LHS if needed. + bool need_dynamic_slice_lhs = false; + auto partition_ordinals = + MakeTiledPartitionOrdinals(lhs.sharding(), partition_id_, &b_); + std::vector zero_padding(hlo->shape().rank()); + PaddingConfig pad_config = + window_util::MakeSymmetricPadding(zero_padding); + auto zero_s32 = b_.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(S32))); + std::vector dynamic_slice_start_indices( + hlo->shape().rank(), zero_s32); + Shape dynamic_slice_shape = lhs.hlo()->shape(); + Shape pad_shape = lhs.hlo()->shape(); + + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + int64 lhs_dimension = dnums.input_spatial_dimensions(i); + int64 rhs_dimension = dnums.kernel_spatial_dimensions(i); + int64 lhs_shard_size = lhs_shard_sizes[i]; + int64 rhs_shard_size = rhs_shard_sizes[i]; + + if (shard_counts[i] == 1) { + continue; + } + + // Calculate the left and right halo sizes as described in the comments + // above. It calculcates the halo sizes with dilation, so we apply + // CeilOfRatio({left,right}_halo_size, window_dilation). 
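+ // [Editor's note, illustrative]: CeilOfRatio is realized by folding the
+ // divisor into the constant term: e.g. with window_dilation D = 3,
+ // CeilOfRatio(halo_size, D) becomes (halo_size + D - 1) / D = (halo_size + 2) / 3,
+ // which is why the MultiplyAddDivideOffsetCalculation terms below add
+ // wd.window_dilation() - 1 before dividing by wd.window_dilation().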
+ auto wd = window.dimensions(i); + int64 padding_low = wd.padding_low(); + int64 padding_high = wd.padding_high(); + int64 base = lhs.base_shape().dimensions(lhs_dimension); + int64 window_count = + 1 + (padding_low + padding_high + base - + (1 + (wd.size() - 1) * wd.window_dilation())) / + wd.stride(); + left_halo_size_functions[rhs_dimension] = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + rhs_shard_size * wd.window_dilation() - lhs_shard_size, + (window_count - 1) * wd.stride() - padding_low + + wd.window_dilation() - 1, + wd.window_dilation())); + right_halo_size_functions[rhs_dimension] = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + lhs_shard_size - rhs_shard_size * wd.window_dilation(), + lhs_shard_size - rhs_shard_size * wd.window_dilation() + + padding_low + wd.window_dilation() - 1, + wd.window_dilation())); + + // New RHS window size includes the maximum of both left and right + // halos. + int64 halo_size = left_halo_size_functions[rhs_dimension].MaxInRange( + 1, shard_counts[i]) + + right_halo_size_functions[rhs_dimension].MaxInRange( + 0, shard_counts[i] - 1); + int64 new_window_size = + rhs.hlo()->shape().dimensions(rhs_dimension) + halo_size; + + // The amount of new low padding could be dynamic (e.g., window_dilation + // != 1), which requires pad (to the maximum) and dynamic slice on LHS. + // + // If we consider the first window, the offset of the dilated RHS that + // aligns with the first valid LHS element for shard i is 'padding_low + + // LHS * i'. When the left halo is added to RHS, the offset of the first + // RHS element is (RHS * i - left_halo) * window_dilation. The + // difference between the two values is the amount of padding_low we + // need on LHS. + auto new_padding_low_function = + OffsetCalculation( + HloOpcode::kMultiply, left_halo_size_functions[rhs_dimension], + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + 0, wd.window_dilation(), 1))) - + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + rhs_shard_size * wd.window_dilation() - lhs_shard_size, + -padding_low, 1)); + + int64 new_padding_low_max = + new_padding_low_function.MaxInRange(0, shard_counts[i]); + int64 new_padding_low = new_padding_low_max; + int64 new_padding_high = window_count * wd.stride() + + (new_window_size - 1) * wd.window_dilation() - + new_padding_low - lhs_shard_size; + + // We do pad/dynamic-slice only when the padding is dynamic. + if (!new_padding_low_function.IsConstant()) { + need_dynamic_slice_lhs = true; + new_padding_low = 0; + pad_config.mutable_dimensions(lhs_dimension) + ->set_edge_padding_low(new_padding_low_max); + pad_config.mutable_dimensions(lhs_dimension) + ->set_edge_padding_high(new_padding_low_max); + pad_shape.set_dimensions(lhs_dimension, + lhs_shard_size + 2 * new_padding_low_max); + dynamic_slice_start_indices[lhs_dimension] = + (OffsetCalculation(MultiplyAddDivideOffsetCalculation( + 0, new_padding_low_max, 1)) - + new_padding_low_function) + .Calculate(partition_ordinals[lhs_dimension], &b_); + dynamic_slice_shape.set_dimensions( + lhs_dimension, lhs_shard_size + new_padding_low_max); + } + + // Since the convolution RHS operand size increased with halos, adjust + // the window config accordingly. 
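+ // [Editor's example, hypothetical sizes]: if the local RHS shard spans 4
+ // elements along this spatial dimension and halo_size is 3, the window size
+ // set below becomes 7, and the original static padding is replaced by the
+ // new_padding_low/new_padding_high derived above.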
+ new_window.mutable_dimensions(i)->set_padding_low(new_padding_low); + new_window.mutable_dimensions(i)->set_padding_high(new_padding_high); + new_window.mutable_dimensions(i)->set_size( + rhs.hlo()->shape().dimensions(rhs_dimension) + halo_size); + } + + HloInstruction* conv_lhs = lhs.hlo(); + if (need_dynamic_slice_lhs) { + auto pad = b_.AddInstruction( + HloInstruction::CreatePad(pad_shape, lhs.hlo(), zero, pad_config)); + conv_lhs = b_.AddInstruction(HloInstruction::CreateDynamicSlice( + dynamic_slice_shape, pad, dynamic_slice_start_indices, + dynamic_slice_shape.dimensions())); + } + + // Exchange halo and concatenate. + HloInstruction* rhs_with_halo = rhs.hlo(); + for (int i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) { + int64 dim = dnums.kernel_spatial_dimensions(i); + int64 explicit_left_padding_on_full_shape = + left_halo_size_functions[dim].Calculate(0); + int64 shard_size_with_halo = new_window.dimensions(i).size(); + + // offset_on_padded_shape and padded_full_shape_size are needed only if + // we want to mask out-of-range values in ExchangeHaloAndGetValidData(). + // Since the default value for both the collective-permute is zero and + // also we call PadWithValue() on both operands at the beginning, we + // don't need to mask here. + // + // TODO(hyoulkee): Consider removing one of the two PadWithValue() calls + // if it's always safe. + auto offset_on_padded_shape = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + rhs_shard_sizes[i], explicit_left_padding_on_full_shape, 1)) - + left_halo_size_functions[dim]; + int64 padded_full_shape_size = + offset_on_padded_shape.Calculate(shard_counts[i] - 1) + + new_window.dimensions(i).size(); + auto concat = ExchangeHaloAndGetValidData( + rhs_with_halo, rhs.base_shape(), left_halo_size_functions[dim], + right_halo_size_functions[dim], explicit_left_padding_on_full_shape, + padded_full_shape_size, shard_size_with_halo, dim, rhs.sharding(), + offset_on_padded_shape.Calculate(partition_ordinals[dim], &b_), + zero, partition_ordinals[dim], collective_ops_creator_, + next_channel_id_, &b_, /*mask_invalid_region=*/false); + if (!concat) { + return DefaultAction(hlo); + } + rhs_with_halo = *concat; + } + + SetPartitionedHlo(hlo, [&]() { + auto conv = b_.AddInstruction(HloInstruction::CreateConvolve( + hlo->shape(), conv_lhs, rhs_with_halo, hlo->feature_group_count(), + hlo->batch_group_count(), new_window, dnums, + hlo->precision_config())); + auto ar = collective_ops_creator_.create_cross_partition_all_reduce( + &b_, conv, MakeBinaryAdd(hlo->shape().element_type(), module_), + NewChannel()); + ar->set_sharding(HloSharding::Replicate()); + return PartitionedHlo(ar, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); + } + } + + if (!sharding.IsTileMaximal()) { + // We don't currently support sharding on output feature dimension. + if (sharding.tile_assignment().dim(dnums.output_feature_dimension()) > 1) { + return DefaultAction(hlo); + } + + // Check if the operand and the output sharding are aligned. 
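+ // [Editor's example, hypothetical layouts]: for a 2D convolution with input
+ // layout [batch, spatial0, spatial1, feature] and output layout
+ // [batch, feature, spatial0, spatial1], input_to_output_indices is
+ // {0, 2, 3, 1}; TransposeSharding() then maps the output sharding back onto
+ // the operand's dimension order.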
+ std::vector input_to_output_indices(hlo->shape().rank()); + input_to_output_indices[dnums.input_batch_dimension()] = + dnums.output_batch_dimension(); + input_to_output_indices[dnums.input_feature_dimension()] = + dnums.output_feature_dimension(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + input_to_output_indices[dnums.input_spatial_dimensions(i)] = + dnums.output_spatial_dimensions(i); + } + auto target_operand_sharding = + hlo_sharding_util::TransposeSharding(sharding, input_to_output_indices); + lhs = lhs.Reshard(target_operand_sharding); + + // Replicate the RHS. + rhs = rhs.Reshard(HloSharding::Replicate()); + + // Convolution window config does not include batch and feature dimensions, + // whereas ReshardAsWindowedInput() expects the same number of window + // dimensions as the rank of the operand. So add two more trivial + // dimensions. + std::vector ones(hlo->shape().rank(), 1); + auto operand_window = window_util::MakeWindow(ones); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + *operand_window.mutable_dimensions(dnums.input_spatial_dimensions(i)) = + hlo->window().dimensions(i); + } + + auto zero = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(hlo->shape().element_type()))); + auto resharded_operand_and_window = lhs.ReshardAsWindowedInput( + operand_window, target_operand_sharding, zero); + if (!resharded_operand_and_window.has_value()) { + return DefaultAction(hlo); + } + Window new_window; + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + *new_window.add_dimensions() = + resharded_operand_and_window->shard_window.dimensions( + dnums.input_spatial_dimensions(i)); + } + TF_ASSIGN_OR_RETURN( + Shape sharded_conv_shape, + ShapeInference::InferConvolveShape( + resharded_operand_and_window->sharded_input->shape(), + rhs.hlo()->shape(), hlo->feature_group_count(), + hlo->batch_group_count(), new_window, dnums)); + auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); + *sharded_conv_shape.mutable_layout() = shard_shape.layout(); + SetPartitionedHlo(hlo, [&]() { + auto sharded_conv = b_.AddInstruction(HloInstruction::CreateConvolve( + sharded_conv_shape, resharded_operand_and_window->sharded_input, + rhs.hlo(), hlo->feature_group_count(), hlo->batch_group_count(), + new_window, dnums, hlo->precision_config())); + if (!resharded_operand_and_window->dynamic_slice_index_on_output + .has_value()) { + CHECK(ShapeUtil::Compatible(shard_shape, sharded_conv->shape())); + return sharded_conv; + } + return b_.AddInstruction(HloInstruction::CreateDynamicSlice( + shard_shape, sharded_conv, + *resharded_operand_and_window->dynamic_slice_index_on_output, + shard_shape.dimensions())); + }); + return Status::OK(); + } + return DefaultAction(hlo); +} + +Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) { + DotGeneralDimsMapping mapping; + const auto& dnums = hlo->dot_dimension_numbers(); + int64 next_output_dim = 0; + for (int64 i = 0; i < dnums.lhs_batch_dimensions_size(); ++i) { + mapping.batch_dims.emplace_back(); + mapping.batch_dims.back().lhs = dnums.lhs_batch_dimensions(i); + mapping.batch_dims.back().rhs = dnums.rhs_batch_dimensions(i); + mapping.batch_dims.back().output = next_output_dim++; + } + for (int64 i = 0; i < dnums.lhs_contracting_dimensions_size(); ++i) { + mapping.contracting_dims.emplace_back(); + mapping.contracting_dims.back().lhs = dnums.lhs_contracting_dimensions(i); + mapping.contracting_dims.back().rhs = dnums.rhs_contracting_dimensions(i); + 
mapping.contracting_dims.back().output = -1; + } + for (int64 i = 0; i < hlo->operand(0)->shape().rank(); ++i) { + if (absl::c_linear_search(dnums.lhs_batch_dimensions(), i) || + absl::c_linear_search(dnums.lhs_contracting_dimensions(), i)) { + continue; + } + mapping.lhs_non_contracting_dims.emplace_back(); + mapping.lhs_non_contracting_dims.back().lhs = i; + mapping.lhs_non_contracting_dims.back().rhs = -1; + mapping.lhs_non_contracting_dims.back().output = next_output_dim++; + } + for (int64 i = 0; i < hlo->operand(1)->shape().rank(); ++i) { + if (absl::c_linear_search(dnums.rhs_batch_dimensions(), i) || + absl::c_linear_search(dnums.rhs_contracting_dimensions(), i)) { + continue; + } + mapping.rhs_non_contracting_dims.emplace_back(); + mapping.rhs_non_contracting_dims.back().lhs = -1; + mapping.rhs_non_contracting_dims.back().rhs = i; + mapping.rhs_non_contracting_dims.back().output = next_output_dim++; + } + auto create_sharded_dot = [&](HloInstruction* l, HloInstruction* r, + SpmdBuilder* b) -> StatusOr { + TF_ASSIGN_OR_RETURN( + auto sharded_dot_shape, + ShapeInference::InferDotOpShape(l->shape(), r->shape(), + hlo->dot_dimension_numbers())); + return b->AddInstruction(HloInstruction::CreateDot( + sharded_dot_shape, l, r, hlo->dot_dimension_numbers(), + hlo->precision_config())); + }; + return HandleDotHelper(hlo, mapping, create_sharded_dot); +} + +Status SpmdPartitioningVisitor::HandleDotHelper( + HloInstruction* hlo, const DotGeneralDimsMapping& dims_mapping, + const std::function( + HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot) { + const HloSharding& lhs_sharding = hlo->operand(0)->sharding(); + const HloSharding& rhs_sharding = hlo->operand(1)->sharding(); + + // Similar to hlo_sharding_util::TransposeSharding(), but allows + // removing/adding non-partitioned dimensions. 
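+ // [Editor's example, hypothetical shapes]: for a [b,m,k] x [b,k,n] dot with
+ // output [b,m,n], lhs_to_output_indices is {0, 1, -1} since the contracting
+ // dimension k has no output counterpart; transpose_sharding() below succeeds
+ // only if the source does not partition such a removed dimension, and it
+ // inserts a size-1 tile dimension for added dimensions (here, n).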
+ auto transpose_sharding = + [&](const HloSharding& source, absl::Span src_to_tgt, + absl::Span tgt_to_src) -> absl::optional { + if (source.IsTileMaximal()) { + return source; + } + std::vector tgt_dims_skipping_new(tgt_to_src.size(), -1); + int64 skipped_tgt_dims = 0; + for (int64 i = 0; i < tgt_to_src.size(); ++i) { + if (tgt_to_src[i] < 0) { + skipped_tgt_dims++; + } else { + tgt_dims_skipping_new[i] = i - skipped_tgt_dims; + } + } + int64 skipped_src_dims = absl::c_count(src_to_tgt, -1); + std::vector perm(src_to_tgt.size()); + for (int64 i = 0; i < src_to_tgt.size(); ++i) { + if (src_to_tgt[i] < 0) { + if (source.tile_assignment().dim(i) > 1) { + return absl::nullopt; + } + perm[src_to_tgt.size() - skipped_src_dims] = i; + skipped_src_dims--; + } else { + perm[tgt_dims_skipping_new[src_to_tgt[i]]] = i; + } + } + auto tgt_sharding = hlo_sharding_util::TransposeSharding(source, perm); + if (skipped_tgt_dims == 0) { + return tgt_sharding; + } + auto reshape_tiles = tgt_sharding.tile_assignment(); + std::vector tgt_tiles(tgt_to_src.size(), 1); + for (int64 i = 0; i < tgt_tiles.size(); ++i) { + if (tgt_to_src[i] >= 0) { + tgt_tiles[i] = reshape_tiles.dim(tgt_dims_skipping_new[i]); + } + } + reshape_tiles.Reshape(tgt_tiles); + return HloSharding::Tile(reshape_tiles); + }; + + std::vector lhs_to_rhs_indices(hlo->operand(0)->shape().rank(), -1); + std::vector lhs_to_output_indices(hlo->operand(0)->shape().rank(), -1); + std::vector rhs_to_lhs_indices(hlo->operand(1)->shape().rank(), -1); + std::vector rhs_to_output_indices(hlo->operand(1)->shape().rank(), -1); + std::vector output_to_lhs_indices(hlo->shape().rank(), -1); + std::vector output_to_rhs_indices(hlo->shape().rank(), -1); + auto populate_indices_mapping = + [&](const DotGeneralDimsMapping::DimsMapping& mapping) { + if (mapping.lhs >= 0) { + lhs_to_rhs_indices[mapping.lhs] = mapping.rhs; + lhs_to_output_indices[mapping.lhs] = mapping.output; + } + if (mapping.rhs >= 0) { + rhs_to_lhs_indices[mapping.rhs] = mapping.lhs; + rhs_to_output_indices[mapping.rhs] = mapping.output; + } + if (mapping.output >= 0) { + output_to_lhs_indices[mapping.output] = mapping.lhs; + output_to_rhs_indices[mapping.output] = mapping.rhs; + } + }; + for (const auto& mapping : dims_mapping.batch_dims) { + populate_indices_mapping(mapping); + } + for (const auto& mapping : dims_mapping.contracting_dims) { + populate_indices_mapping(mapping); + } + for (const auto& mapping : dims_mapping.lhs_non_contracting_dims) { + populate_indices_mapping(mapping); + } + for (const auto& mapping : dims_mapping.rhs_non_contracting_dims) { + populate_indices_mapping(mapping); + } + auto lhs_sharding_transposed_to_match_rhs = + transpose_sharding(lhs_sharding, lhs_to_rhs_indices, rhs_to_lhs_indices); + auto rhs_sharding_transposed_to_match_lhs = + transpose_sharding(rhs_sharding, rhs_to_lhs_indices, lhs_to_rhs_indices); + auto lhs_sharding_transposed_to_match_output = transpose_sharding( + lhs_sharding, lhs_to_output_indices, output_to_lhs_indices); + auto rhs_sharding_transposed_to_match_output = transpose_sharding( + rhs_sharding, rhs_to_output_indices, output_to_rhs_indices); + auto output_sharding_transposed_to_match_lhs = transpose_sharding( + hlo->sharding(), output_to_lhs_indices, lhs_to_output_indices); + auto output_sharding_transposed_to_match_rhs = transpose_sharding( + hlo->sharding(), output_to_rhs_indices, rhs_to_output_indices); + + // lhs_rhs_or_output: 0 lhs, 1 rhs, 2 output. 
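+ // [Editor's example, hypothetical tiling]: if the LHS tile assignment has
+ // dimensions [2, 1, 4], the batch dims map to LHS dimension 0 and the
+ // contracting dims to LHS dimension 2, then get_partitions_for_dims() below
+ // returns 2 for the batch dims and 4 for the contracting dims.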
+ auto get_partitions_for_dims = + [&](const HloSharding& sharding, + absl::Span dims, + int lhs_rhs_or_output) { + int64 partitions = 1; + if (sharding.IsTileMaximal()) { + return partitions; + } + for (const auto& dim : dims) { + if (lhs_rhs_or_output == 0) { + partitions *= sharding.tile_assignment().dim(dim.lhs); + } else if (lhs_rhs_or_output == 1) { + partitions *= sharding.tile_assignment().dim(dim.rhs); + } else { + CHECK_EQ(lhs_rhs_or_output, 2); + partitions *= sharding.tile_assignment().dim(dim.output); + } + } + return partitions; + }; + const int64 lhs_batch_partitions = + get_partitions_for_dims(lhs_sharding, dims_mapping.batch_dims, 0); + const int64 rhs_batch_partitions = + get_partitions_for_dims(rhs_sharding, dims_mapping.batch_dims, 1); + const int64 output_batch_partitions = + get_partitions_for_dims(hlo->sharding(), dims_mapping.batch_dims, 2); + const int64 lhs_contracting_partitions = + get_partitions_for_dims(lhs_sharding, dims_mapping.contracting_dims, 0); + const int64 rhs_contracting_partitions = + get_partitions_for_dims(rhs_sharding, dims_mapping.contracting_dims, 1); + const int64 lhs_non_contracting_partitions = get_partitions_for_dims( + lhs_sharding, dims_mapping.lhs_non_contracting_dims, 0); + const int64 rhs_non_contracting_partitions = get_partitions_for_dims( + rhs_sharding, dims_mapping.rhs_non_contracting_dims, 1); + const int64 output_lhs_non_contracting_partitions = get_partitions_for_dims( + hlo->sharding(), dims_mapping.lhs_non_contracting_dims, 2); + const int64 output_rhs_non_contracting_partitions = get_partitions_for_dims( + hlo->sharding(), dims_mapping.rhs_non_contracting_dims, 2); + + auto& lhs = GetPartitionedHlo(hlo->operand(0)); + auto& rhs = GetPartitionedHlo(hlo->operand(1)); + // LHS and RHS are partitioned the same way and only partitioned in batch + // dimensions. + if (lhs_batch_partitions == rhs_batch_partitions && + rhs_batch_partitions == num_partitions_ && + lhs_sharding_transposed_to_match_rhs == rhs_sharding) { + TF_ASSIGN_OR_RETURN(auto dot, + create_sharded_dot(lhs.hlo(), rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { + dot->set_sharding(*lhs_sharding_transposed_to_match_output); + return PartitionedHlo(dot, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); + } + + // Try emit batch-partitioned einsum with one operand resharded. Returns + // whether the attempt succeeds. If may_reshard_with_allreduce is false, + // reshard must be done using all-to-all; otherwise this attempt fails. + auto try_emit_output_batch_partitioned_einsum_with_reshard = + [&](bool may_reshard_with_allreduce) -> StatusOr { + // LHS and output are batch partitioned in the same way. + if (lhs_batch_partitions == num_partitions_ && + output_batch_partitions == num_partitions_ && + lhs_sharding_transposed_to_match_output == hlo->sharding()) { + if (!may_reshard_with_allreduce && + !CanReshardWithAllToAll(rhs.sharding(), + *lhs_sharding_transposed_to_match_rhs)) { + return false; + } + auto resharded_rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs); + TF_ASSIGN_OR_RETURN( + auto dot, create_sharded_dot(lhs.hlo(), resharded_rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { return dot; }); + return true; + } + // RHS and output are batch partitioned in the same way. 
+ if (rhs_batch_partitions == num_partitions_ && + output_batch_partitions == num_partitions_ && + rhs_sharding_transposed_to_match_output == hlo->sharding()) { + if (!may_reshard_with_allreduce && + !CanReshardWithAllToAll(lhs.sharding(), + *rhs_sharding_transposed_to_match_lhs)) { + return false; + } + auto resharded_lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs); + TF_ASSIGN_OR_RETURN( + auto dot, create_sharded_dot(resharded_lhs.hlo(), rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { return dot; }); + return true; + } + return false; + }; + + { + // Try batch-parallel by resharding one operand, and not using all-reduce. + TF_ASSIGN_OR_RETURN( + bool emitted, + try_emit_output_batch_partitioned_einsum_with_reshard(false)); + if (emitted) { + return Status::OK(); + } + } + + // Try to emit windowed DotGeneral when one operand is partitioned in the same + // way as the output along non-contracting dimensions, but the other operand + // is tiled in other dimensions. + auto emit_windowed_dot_general = [&](int64 matching_operand, + int64 windowing_operand, + bool windowed_at_contracting_dims, + bool windowed_at_batch_dims) { + CHECK_EQ(matching_operand + windowing_operand, 1); + CHECK(!windowed_at_batch_dims || !windowed_at_contracting_dims); + auto unpadded_result_buffer_shape = + MakePartitionedShape(hlo->shape(), hlo->sharding()); + auto padded_result_buffer_shape = unpadded_result_buffer_shape; + // For windowing at batch/non-contracting dims, we produce the result one + // partition at a time, so we need to pad the shape in case of uneven + // partitioning in order to make dynamic-update-slice in-bound. + if (!windowed_at_contracting_dims) { + padded_result_buffer_shape = GetPaddedShapeForUnevenPartitioning( + padded_result_buffer_shape, + windowing_operand == 0 ? *lhs_sharding_transposed_to_match_output + : *rhs_sharding_transposed_to_match_output); + } + // Mask the padding area of the windowed operand with zero if there is + // uneven partitioning. + if (windowed_at_contracting_dims) { + auto& to_mask = windowing_operand == 0 ? lhs : rhs; + to_mask = + to_mask.PadWithValue(b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(hlo->shape().element_type())))); + } + auto result_buffer = CreateZero(padded_result_buffer_shape, &b_); + auto iteration = b_.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + + // Create a while loop that computes one window per iteration. During each + // iteration, each partition sends its input window to its neighbor using + // collective-permute for the next iteration. 
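+ // [Editor's walkthrough, hypothetical partition count]: with
+ // num_partitions_ = 4, partition p works on the chunk of partition
+ // (p + i) mod 4 at iteration i (see data_partition_id below); e.g.
+ // partition 1 sees chunks 1, 2, 3, 0 over the four iterations, receiving
+ // each next chunk from partition 2 via collective-permute.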
+ SpmdBuilder body_b("windowed_dot_general_body", visiting_hlo_); + auto param = body_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, + ShapeUtil::MakeTupleShape({lhs.hlo()->shape(), rhs.hlo()->shape(), + result_buffer->shape(), iteration->shape()}), + "param")); + auto l = body_b.AddInstruction( + HloInstruction::CreateGetTupleElement(lhs.hlo()->shape(), param, 0)); + auto r = body_b.AddInstruction( + HloInstruction::CreateGetTupleElement(rhs.hlo()->shape(), param, 1)); + auto o = body_b.AddInstruction(HloInstruction::CreateGetTupleElement( + result_buffer->shape(), param, 2)); + auto i = body_b.AddInstruction( + HloInstruction::CreateGetTupleElement(iteration->shape(), param, 3)); + + auto partition_id = collective_ops_creator_.create_partition_id(&body_b); + auto data_partition_id = body_b.AddInstruction(HloInstruction::CreateBinary( + i->shape(), HloOpcode::kAdd, i, partition_id)); + auto partition_count = body_b.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(num_partitions_))); + data_partition_id = body_b.AddInstruction(HloInstruction::CreateBinary( + i->shape(), HloOpcode::kRemainder, data_partition_id, partition_count)); + auto dot_lhs = l; + auto dot_rhs = r; + if (windowed_at_contracting_dims || windowed_at_batch_dims) { + // Slice the matching operand according to the partitioned contracting + // dimensions on the windowed operand. We do this by treating the matching + // operand as replicated, and resharding it to match the windowed operand. + auto slice_operand = matching_operand == 0 ? l : r; + slice_operand->set_sharding(HloSharding::Replicate()); + auto state = MakePartitioningState(); + state.b = &body_b; + state.partition_id = data_partition_id; + auto slice = PartitionedHlo(slice_operand, slice_operand->shape(), state) + .Reshard(windowing_operand == 0 + ? *lhs_sharding_transposed_to_match_rhs + : *rhs_sharding_transposed_to_match_lhs) + .hlo(); + slice_operand->clear_sharding(); + if (matching_operand == 0) { + dot_lhs = slice; + } else { + dot_rhs = slice; + } + } + TF_ASSIGN_OR_RETURN(auto dot, + create_sharded_dot(dot_lhs, dot_rhs, &body_b)); + if (windowed_at_contracting_dims) { + // Accumulate the partial output to the result buffer. + o = body_b.AddInstruction( + HloInstruction::CreateBinary(o->shape(), HloOpcode::kAdd, o, dot)); + } else { + // The windowing operand is partitioned along batch/non-contracting + // dimensions, so we need a dynamic-update-slice to save the partial + // output in the result buffer. + auto offsets = MakePartitionOffsets( + o->shape(), + windowing_operand == 0 ? *lhs_sharding_transposed_to_match_output + : *rhs_sharding_transposed_to_match_output, + data_partition_id, &body_b); + o = body_b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + o->shape(), o, dot, offsets)); + } + + // ++i + i = body_b.AddInstruction(HloInstruction::CreateBinary( + i->shape(), HloOpcode::kAdd, i, + body_b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))))); + auto has_more = body_b.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), i, + body_b.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(num_partitions_))), + ComparisonDirection::kLt)); + // Collective-permute for the next window. We don't need it for the last + // iteration, so we use a conditional around the collective-permute. 
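+ // [Editor's example, hypothetical partition count]: with 4 partitions the
+ // source-destination pairs built below are {0->3, 1->0, 2->1, 3->2}, and on
+ // the last iteration the false branch simply forwards the window unchanged.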
+ HloInstruction* conditional; + { + SpmdBuilder cp_b("window_collective_permute", visiting_hlo_); + { + auto p = cp_b.AddInstruction(HloInstruction::CreateParameter( + 0, windowing_operand == 0 ? l->shape() : r->shape(), "window")); + std::vector> sd_pairs(num_partitions_); + for (int64 source = 0; source < num_partitions_; ++source) { + // 0 -> n-1, 1 -> 0, 2 -> 1, ... + sd_pairs[source] = {source, + (source - 1 + num_partitions_) % num_partitions_}; + } + collective_ops_creator_.create_cross_partition_collective_permute( + &cp_b, p, sd_pairs, (*next_channel_id_)++); + } + SpmdBuilder ncp_b("last_iteration_noop", visiting_hlo_); + { + ncp_b.AddInstruction(HloInstruction::CreateParameter( + 0, windowing_operand == 0 ? l->shape() : r->shape(), "window")); + } + conditional = body_b.AddInstruction(HloInstruction::CreateConditional( + windowing_operand == 0 ? l->shape() : r->shape(), has_more, + windowing_operand == 0 ? l : r, + module_->AddEmbeddedComputation(cp_b.Build()), + windowing_operand == 0 ? l : r, + module_->AddEmbeddedComputation(ncp_b.Build()))); + } + if (windowing_operand == 0) { + l = conditional; + } else { + r = conditional; + } + body_b.AddInstruction(HloInstruction::CreateTuple({l, r, o, i})); + + SpmdBuilder cond_b("windowed_dot_general_cond", visiting_hlo_); + auto cond_param = cond_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, + ShapeUtil::MakeTupleShape({lhs.hlo()->shape(), rhs.hlo()->shape(), + result_buffer->shape(), iteration->shape()}), + "param")); + auto cond_i = cond_b.AddInstruction(HloInstruction::CreateGetTupleElement( + iteration->shape(), cond_param, 3)); + cond_b.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), cond_i, + cond_b.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(num_partitions_))), + ComparisonDirection::kLt)); + auto while_loop = b_.AddInstruction(HloInstruction::CreateWhile( + cond_param->shape(), module_->AddEmbeddedComputation(cond_b.Build()), + module_->AddEmbeddedComputation(body_b.Build()), + b_.AddInstruction(HloInstruction::CreateTuple( + {lhs.hlo(), rhs.hlo(), result_buffer, iteration})))); + windowed_dot_general_loops_.push_back({while_loop, windowing_operand, + windowed_at_contracting_dims, + windowed_at_batch_dims}); + SetPartitionedHlo(hlo, [&] { + auto result = b_.AddInstruction(HloInstruction::CreateGetTupleElement( + result_buffer->shape(), while_loop, 2)); + if (!ShapeUtil::Compatible(padded_result_buffer_shape, + unpadded_result_buffer_shape)) { + result = b_.AddInstruction(HloInstruction::CreateSlice( + unpadded_result_buffer_shape, result, + std::vector(padded_result_buffer_shape.rank(), 0), + unpadded_result_buffer_shape.dimensions(), + std::vector(padded_result_buffer_shape.rank(), 1))); + } + return result; + }); + return Status::OK(); + }; + if (output_lhs_non_contracting_partitions == num_partitions_ && + output_sharding_transposed_to_match_lhs == lhs_sharding && + ShapeSizeInBytes(hlo->operand(1)->shape()) >= + options_.threshold_for_windowed_einsum_mib * 1024 * 1024) { + if (rhs_contracting_partitions == num_partitions_) { + return emit_windowed_dot_general(0, 1, true, false); + } + if (rhs_non_contracting_partitions == num_partitions_) { + return emit_windowed_dot_general(0, 1, false, false); + } + if (rhs_batch_partitions == num_partitions_) { + return emit_windowed_dot_general(0, 1, false, true); + } + } + if (output_rhs_non_contracting_partitions == num_partitions_ && + output_sharding_transposed_to_match_rhs == rhs_sharding && 
+ ShapeSizeInBytes(hlo->operand(0)->shape()) >= + options_.threshold_for_windowed_einsum_mib * 1024 * 1024) { + if (lhs_contracting_partitions == num_partitions_) { + return emit_windowed_dot_general(1, 0, true, false); + } + if (lhs_non_contracting_partitions == num_partitions_) { + return emit_windowed_dot_general(1, 0, false, false); + } + if (lhs_batch_partitions == num_partitions_) { + return emit_windowed_dot_general(1, 0, false, true); + } + } + + { + // Try batch-parallel by resharding one operand, and allowing all-reduce. + TF_ASSIGN_OR_RETURN( + bool emitted, + try_emit_output_batch_partitioned_einsum_with_reshard(true)); + if (emitted) { + return Status::OK(); + } + } + + // LHS and RHS have the same partitioned contracting dimensions. + if (lhs_contracting_partitions == rhs_contracting_partitions && + lhs_contracting_partitions == num_partitions_) { + auto zero = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(hlo->shape().element_type()))); + // Pad both sides with zero, since NaN at one side cannot be masked by zero + // on the other side. + if (ShapeSizeInBytes(lhs.base_shape()) < + ShapeSizeInBytes(rhs.base_shape())) { + lhs = + lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithValue(zero); + rhs = rhs.PadWithValue(zero); + } else { + lhs = lhs.PadWithValue(zero); + rhs = + rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero); + } + TF_ASSIGN_OR_RETURN(auto dot, + create_sharded_dot(lhs.hlo(), rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { + auto ar = collective_ops_creator_.create_cross_partition_all_reduce( + &b_, dot, MakeBinaryAdd(hlo->shape().element_type(), module_), + NewChannel()); + ar->set_sharding(HloSharding::Replicate()); + return PartitionedHlo(ar, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); + } + + // LHS and output have the same partitioned non-contracting dimensions. + if (lhs_non_contracting_partitions == num_partitions_ && + output_lhs_non_contracting_partitions == num_partitions_ && + lhs_sharding == hlo->sharding()) { + auto rhs_replicated = rhs.Reshard(HloSharding::Replicate()).hlo(); + TF_ASSIGN_OR_RETURN(auto dot, + create_sharded_dot(lhs.hlo(), rhs_replicated, &b_)); + SetPartitionedHlo(hlo, [&] { return dot; }); + return Status::OK(); + } + + // RHS and output have the same partitioned non-contracting dimensions. + if (rhs_non_contracting_partitions == num_partitions_ && + output_rhs_non_contracting_partitions == num_partitions_ && + rhs_sharding_transposed_to_match_output == hlo->sharding()) { + auto lhs_replicated = lhs.Reshard(HloSharding::Replicate()).hlo(); + TF_ASSIGN_OR_RETURN(auto dot, + create_sharded_dot(lhs_replicated, rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { return dot; }); + return Status::OK(); + } + + // Output is batch partitioned. + if (output_batch_partitions == num_partitions_) { + auto resharded_lhs = lhs.Reshard(*output_sharding_transposed_to_match_lhs); + auto resharded_rhs = rhs.Reshard(*output_sharding_transposed_to_match_rhs); + TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(resharded_lhs.hlo(), + resharded_rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { return dot; }); + return Status::OK(); + } + // Output is partitioned along LHS non-contracting dimensions. 
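+ // [Editor's example, hypothetical shapes]: for an [m, k] x [k, n] matmul
+ // whose output is tiled along m, the LHS is resharded to match that
+ // m-partitioning and the RHS is replicated, so each partition computes its
+ // output tile locally without a cross-partition reduction.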
+ if (output_lhs_non_contracting_partitions == num_partitions_) { + auto resharded_lhs = lhs.Reshard(*output_sharding_transposed_to_match_lhs); + auto replicated_rhs = rhs.Reshard(HloSharding::Replicate()); + TF_ASSIGN_OR_RETURN( + auto dot, + create_sharded_dot(resharded_lhs.hlo(), replicated_rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { return dot; }); + return Status::OK(); + } + // Output is partitioned along RHS non-contracting dimensions. + if (output_rhs_non_contracting_partitions == num_partitions_) { + auto replicated_lhs = lhs.Reshard(HloSharding::Replicate()); + auto resharded_rhs = rhs.Reshard(*output_sharding_transposed_to_match_rhs); + TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(replicated_lhs.hlo(), + resharded_rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { return dot; }); + return Status::OK(); + } + + // Returns true if it is beneficial to reshard the operand at `operand_idx` + // across the contracting dimension. + const auto should_partition_contracting_dim = [&](int64 operand_idx) { + if (!hlo->sharding().IsReplicated()) { + return false; + } + + if (operand_idx == 0) { + // If LHS and output are replicated, we compare the cost of all-gather + // on RHS vs all-reduce on the output. + return (rhs_contracting_partitions == num_partitions_) && + lhs.sharding().IsReplicated() && + ShapeUtil::ElementsIn(hlo->operand(1)->shape()) > + ShapeUtil::ElementsIn(hlo->shape()); + } else { + return (lhs_contracting_partitions == num_partitions_) && + rhs.sharding().IsReplicated() && + ShapeUtil::ElementsIn(hlo->operand(0)->shape()) > + ShapeUtil::ElementsIn(hlo->shape()); + } + }; + + // When the output is replicated and one of the operands is partitioned along + // contracting dimension, align the other operand to be partitioned along + // the contracting dimensions. + if (hlo->sharding().IsReplicated() && (should_partition_contracting_dim(0) || + should_partition_contracting_dim(1))) { + auto zero = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(hlo->shape().element_type()))); + if (should_partition_contracting_dim(0)) { + lhs = + lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithValue(zero); + rhs = rhs.PadWithValue(zero); + } else { + lhs = lhs.PadWithValue(zero); + rhs = + rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero); + } + TF_ASSIGN_OR_RETURN(auto dot, + create_sharded_dot(lhs.hlo(), rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { + auto ar = collective_ops_creator_.create_cross_partition_all_reduce( + &b_, dot, MakeBinaryAdd(hlo->shape().element_type(), module_), + NewChannel()); + ar->set_sharding(HloSharding::Replicate()); + return PartitionedHlo(ar, hlo->shape(), MakePartitioningState()).hlo(); + }); + return Status::OK(); + } + + return DefaultAction(hlo); +} + +namespace { + +// Finds a cluster of nodes that produce the inputs for `hlo` which only depend +// on small operands, which means the cluster should start with broadcasts, +// constants and iotas. All other internal nodes must be non-side-effecting +// elemntwise ops. Returns the set of nodes, and the small operands. E.g., for +// the following graph, +// +// a -> broadcast -> multiply +// iota ---> add--/ +// constant/ +// +// FindInputNodesIfOnlyDependOnSmallOperands(multiply) will return +// <{broadcast, iota, constant, add, multiply}, [a]>. 
+std::pair, std::vector> +FindInputNodesIfOnlyDependOnSmallOperands(HloInstruction* hlo) { + std::unordered_set nodes_found; + std::vector new_operands; + std::unordered_set new_operands_set; + std::vector worklist; + worklist.push_back(hlo); + while (!worklist.empty()) { + auto inst = worklist.back(); + worklist.pop_back(); + if (nodes_found.count(inst) > 0) { + continue; + } + if (inst->opcode() == HloOpcode::kBroadcast || + inst->opcode() == HloOpcode::kConstant || + inst->opcode() == HloOpcode::kIota) { + nodes_found.insert(inst); + for (auto o : inst->operands()) { + auto res = new_operands_set.emplace(o); + if (res.second) { + new_operands.push_back(o); + } + } + } else if (inst->IsElementwise() && !inst->HasSideEffectNoRecurse() && + inst->opcode() != HloOpcode::kAllReduce && + absl::c_all_of(inst->operands(), + [inst](const HloInstruction* o) { + return ShapeUtil::CompatibleIgnoringElementType( + o->shape(), inst->shape()); + })) { + nodes_found.insert(inst); + for (auto o : inst->operands()) { + worklist.push_back(o); + } + } else { + nodes_found.clear(); + new_operands.clear(); + break; + } + } + return {std::move(nodes_found), std::move(new_operands)}; +} + +// Moves a cluster of memory-reducing nodes into the windowed dot-general loop +// on contracting dimensions. Such a loop has a dynamic slice on the +// non-windowed operand. If we move the input nodes into the loop, the +// dynamic-slice could be merged with them by later optimization passes, which +// reduces memory. +// +// small_operands small_operands +// | | +// input_nodes loop { | +// | => input_nodes +// loop { | | +// dynamic-slice dynamic-slice +// ... ... +// } } +// +// Later optimization passes (TpuPadSliceMover) will merge the dynamic slice +// with the input nodes. +Status SinkInputNodesIntoWindowedDotGeneralLoopOnContractingDimensions( + HloInstruction* loop, int64 non_windowed_operand_index) { + auto input_tuple = loop->mutable_operand(0); + auto old_operand = input_tuple->mutable_operand(non_windowed_operand_index); + auto input_nodes = FindInputNodesIfOnlyDependOnSmallOperands(old_operand); + auto to_sink = std::move(input_nodes.first); + auto new_operands = std::move(input_nodes.second); + if (to_sink.empty()) { + return Status::OK(); + } + auto computation = loop->parent(); + // Replace the old operand with a tuple of the found small operands. + auto new_input_subtuple = + computation->AddInstruction(HloInstruction::CreateTuple(new_operands)); + TF_RETURN_IF_ERROR(input_tuple->ReplaceOperandWithDifferentShape( + non_windowed_operand_index, new_input_subtuple)); + + auto body = loop->while_body(); + auto body_param = body->parameter_instruction(0); + auto old_body_param_users = body_param->users(); + // Update all tuple shapes. + for (auto tuple : std::vector{ + input_tuple, loop, loop->while_condition()->parameter_instruction(0), + body_param, body->root_instruction()}) { + *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(), + {non_windowed_operand_index}) = + new_input_subtuple->shape(); + } + // Now update the loop body. + auto new_operand_tuple_inside = + body->AddInstruction(HloInstruction::CreateGetTupleElement( + new_input_subtuple->shape(), body_param, non_windowed_operand_index)); + TF_RETURN_IF_ERROR(body->root_instruction()->ReplaceOperandWithDifferentShape( + non_windowed_operand_index, new_operand_tuple_inside)); + + // Create nodes inside the loop body. 
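+ // [Editor's sketch of the cloning order]: outside_to_inside is seeded with
+ // get-tuple-elements for the small operands; nullary ops (constants, iotas)
+ // go onto the worklist first, and any other node, e.g. add(broadcast(x), iota),
+ // is queued only once all of its operands already have counterparts inside
+ // the loop, so the cluster is rebuilt in dependency order.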
+ std::vector worklist; + std::unordered_map outside_to_inside; + auto add_users_if_available = [&](HloInstruction* inst) { + for (auto u : inst->users()) { + if (outside_to_inside.count(u) == 0 && to_sink.count(u) > 0 && + absl::c_all_of(u->operands(), [&](const HloInstruction* o) { + return outside_to_inside.count(o) > 0; + })) { + worklist.push_back(u); + } + } + }; + for (int64 i = 0; i < new_operands.size(); ++i) { + outside_to_inside[new_operands[i]] = + body->AddInstruction(HloInstruction::CreateGetTupleElement( + new_operands[i]->shape(), new_operand_tuple_inside, i)); + add_users_if_available(new_operands[i]); + } + // HLOs to sink without operands. + std::vector nullaries_to_sink; + for (auto inst : to_sink) { + if (inst->operand_count() == 0) { + nullaries_to_sink.push_back(inst); + } + } + // Sort nullaries_to_sink to make it deterministic. + absl::c_sort(nullaries_to_sink, + [](const HloInstruction* a, const HloInstruction* b) { + return a->unique_id() < b->unique_id(); + }); + for (auto inst : nullaries_to_sink) { + worklist.push_back(inst); + } + while (!worklist.empty()) { + auto inst = worklist.back(); + worklist.pop_back(); + std::vector inst_new_operands(inst->operand_count()); + for (int64 i = 0; i < inst->operand_count(); ++i) { + inst_new_operands[i] = outside_to_inside[inst->operand(i)]; + } + outside_to_inside[inst] = body->AddInstruction( + inst->CloneWithNewOperands(inst->shape(), inst_new_operands)); + add_users_if_available(inst); + } + TF_RET_CHECK(outside_to_inside.count(old_operand) > 0); + for (auto ou : old_body_param_users) { + if (ou->opcode() == HloOpcode::kGetTupleElement && + ou->tuple_index() == non_windowed_operand_index) { + TF_RETURN_IF_ERROR( + ou->ReplaceAllUsesWith(outside_to_inside[old_operand])); + TF_RETURN_IF_ERROR(body->RemoveInstruction(ou)); + } + } + return Status::OK(); +} + +// Moves a cluster of memory-reducing nodes (with reduce nodes at the end) into +// the windowed dot-general loop on non-contracting dimensions. Such a loop has +// a dynamic-update-slice at the output. If we move the user nodes into the loop +// and before the dynamic-update-slice, the user nodes can operate on smaller +// shapes, which reduces memory. +// +// small_operands small_operands +// | | => | | +// | | loop { loop { | | +// | | conv | broadcast conv +// | | | | | / +// | | dynamic-update-slice | dynamic-slice / +// | | | | | / +// | | } | | multiply----- +// |broadcast / | / +// | | / reduce +// |multiply-- | +// \ | dynamic-update-slice +// reduce } +// +// Later optimization passes (TpuPadSliceMover) will merge the dynamic slice +// with the input nodes (broadcast). +Status MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions( + HloInstruction* loop) { + CHECK_EQ(loop->user_count(), 1); + // There should be a single direct user of the while loop, which is the + // gte for element 2, i.e., the dot output. + auto user_gte = loop->users().front(); + CHECK_EQ(user_gte->opcode(), HloOpcode::kGetTupleElement); + CHECK_EQ(user_gte->tuple_index(), 2); + auto computation = loop->parent(); + + // Find the reduce outputs and the input nodes they depend on, if input nodes + // only have small operands. 
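+ // [Editor's example, hypothetical graph]: for a user chain
+ // loop_output -> multiply(_, broadcast(scalar)) -> reduce(add, init=0),
+ // the traversal below collects {multiply, broadcast, reduce}, records the
+ // scalar and the reduce init value as small operands, and appends the
+ // reduce to reduce_outputs.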
+ std::unordered_set to_move; + std::vector new_operands; + std::unordered_set new_operands_set; + std::vector reduce_outputs; + std::vector worklist; + Shape padded_shape = user_gte->shape(); + Shape unpadded_shape = user_gte->shape(); + auto original_output = user_gte; + + if (user_gte->user_count() == 1 && + user_gte->users().back()->opcode() == HloOpcode::kSlice) { + original_output = user_gte->users().back(); + unpadded_shape = original_output->shape(); + } + for (auto u : original_output->users()) { + worklist.push_back(u); + } + to_move.insert(original_output); + while (!worklist.empty()) { + auto inst = worklist.back(); + worklist.pop_back(); + if (to_move.count(inst) > 0) { + continue; + } + // We only support reduces with simple reduction function, since we may need + // to accumulate across iterations manually. + if (inst->opcode() == HloOpcode::kReduce && + inst->to_apply()->instruction_count() == 3 && + inst->to_apply()->num_parameters() == 2 && + inst->to_apply()->root_instruction()->IsElementwise()) { + to_move.insert(inst); + auto other_operand = inst->mutable_operand(1); + auto res = new_operands_set.emplace(other_operand); + if (res.second) { + new_operands.push_back(other_operand); + } + reduce_outputs.push_back(inst); + } else if (inst != computation->root_instruction() && + inst->user_count() > 0 && inst->IsElementwise() && + !inst->HasSideEffectNoRecurse() && + inst->opcode() != HloOpcode::kAllReduce && + absl::c_all_of(inst->operands(), + [inst](const HloInstruction* o) { + return ShapeUtil::CompatibleIgnoringElementType( + o->shape(), inst->shape()); + })) { + // For an elementwise op, we need to make sure that they depend on only + // nodes already in to_move and nodes with small operands. + bool can_include = true; + for (auto operand : inst->operands()) { + if (to_move.count(operand) > 0) { + continue; + } + auto find_result = FindInputNodesIfOnlyDependOnSmallOperands(operand); + if (find_result.first.empty()) { + can_include = false; + break; + } + for (auto n : find_result.first) { + to_move.insert(n); + } + for (auto new_operand : find_result.second) { + auto res = new_operands_set.insert(new_operand); + if (res.second) { + new_operands.push_back(new_operand); + } + } + } + if (!can_include) { + to_move.clear(); + break; + } + to_move.insert(inst); + for (auto u : inst->users()) { + worklist.push_back(u); + } + } else { + to_move.clear(); + break; + } + } + // If nothing is found, to_move could contain only original_output, or cleared + // by the above code. + if (to_move.size() <= 1) { + return Status::OK(); + } + + // We will replace the original loop output with reduce-shape outputs. Create + // the initial buffers before the loop. + for (auto out : reduce_outputs) { + auto padded_out_shape = out->shape(); + int64 operand_dim = 0; + int64 output_dim = 0; + while (output_dim < padded_out_shape.rank()) { + if (absl::c_linear_search(out->dimensions(), operand_dim)) { + // Dimension colapsed. + ++operand_dim; + continue; + } + // Kept dimensions have the same size of the padded shape. 
+ padded_out_shape.set_dimensions(output_dim, + padded_shape.dimensions(operand_dim)); + ++operand_dim; + ++output_dim; + } + auto broadcast = + computation->AddInstruction(HloInstruction::CreateBroadcast( + padded_out_shape, + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(out->shape().element_type()))), + {})); + new_operands.push_back(broadcast); + } + + auto input_tuple = loop->mutable_operand(0); + // Create the new input subtuple that contains the small operands and the + // reduce-shape result buffers. + auto new_input_subtuple = + computation->AddInstruction(HloInstruction::CreateTuple(new_operands)); + TF_RETURN_IF_ERROR( + input_tuple->ReplaceOperandWithDifferentShape(2, new_input_subtuple)); + auto body = loop->while_body(); + auto body_param = body->parameter_instruction(0); + auto body_root = body->root_instruction(); + CHECK_EQ(body_root->opcode(), HloOpcode::kTuple); + // Update tuple shapes. + for (auto tuple : std::vector{ + input_tuple, loop, loop->while_condition()->parameter_instruction(0), + body_param, body_root}) { + *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(), {2}) = + new_input_subtuple->shape(); + } + auto new_loop_input = + body->AddInstruction(HloInstruction::CreateGetTupleElement( + new_input_subtuple->shape(), body_param, 2)); + + // Now create the moved nodes inside the loop body. + std::unordered_map outside_to_inside; + worklist.clear(); + auto add_users_if_available = [&](HloInstruction* inst) { + for (auto u : inst->users()) { + if (outside_to_inside.count(u) == 0 && to_move.count(u) > 0 && + absl::c_all_of(u->operands(), [&](const HloInstruction* o) { + return outside_to_inside.count(o) > 0; + })) { + worklist.push_back(u); + } + } + }; + for (int64 i = 0; i < new_operands.size(); ++i) { + outside_to_inside[new_operands[i]] = + body->AddInstruction(HloInstruction::CreateGetTupleElement( + new_operands[i]->shape(), new_loop_input, i)); + add_users_if_available(new_operands[i]); + } + // The elementwise nodes will be created with sliced shape. The original loop + // output corresponds to the dynamic-update-slice's update slice. + auto dus = body_root->mutable_operand(2); + CHECK_EQ(dus->opcode(), HloOpcode::kDynamicUpdateSlice); + outside_to_inside[original_output] = dus->mutable_operand(1); + add_users_if_available(original_output); + std::vector slice_offsets(padded_shape.rank()); + for (int64 i = 0; i < slice_offsets.size(); ++i) { + slice_offsets[i] = dus->mutable_operand(i + 2); + } + auto get_slice = [&](HloInstruction* padded) { + return body->AddInstruction(HloInstruction::CreateDynamicSlice( + ShapeUtil::ChangeElementType(dus->operand(1)->shape(), + padded->shape().element_type()), + padded, slice_offsets, dus->operand(1)->shape().dimensions())); + }; + // Helper functions to create nodes with small operands. 
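+ // [Editor's note, illustrative]: each sunk broadcast, iota or constant is
+ // first materialized (or padded) to the full padded_shape and then
+ // dynamic-sliced with the same offsets as the loop's dynamic-update-slice,
+ // so each iteration only sees its own shard of the value.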
+ auto add_broadcast = [&](const HloInstruction* broadcast) { + auto padded_operand_shape = broadcast->operand(0)->shape(); + for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { + padded_operand_shape.set_dimensions( + i, padded_shape.dimensions(broadcast->dimensions(i))); + } + auto padded_operand = PadToShape(outside_to_inside[broadcast->operand(0)], + padded_operand_shape, nullptr, body); + outside_to_inside[broadcast] = + get_slice(body->AddInstruction(broadcast->CloneWithNewOperands( + ShapeUtil::ChangeElementType(padded_shape, + padded_operand_shape.element_type()), + {padded_operand}))); + }; + auto add_iota = [&](const HloInstruction* iota) { + outside_to_inside[iota] = + get_slice(body->AddInstruction(iota->CloneWithNewOperands( + ShapeUtil::ChangeElementType(padded_shape, + iota->shape().element_type()), + {}))); + }; + auto add_constant = [&](const HloInstruction* constant) { + outside_to_inside[constant] = body->AddInstruction(constant->Clone()); + outside_to_inside[constant] = get_slice( + PadToShape(outside_to_inside[constant], + ShapeUtil::ChangeElementType( + padded_shape, constant->shape().element_type()), + nullptr, body)); + }; + while (!worklist.empty()) { + auto inst = worklist.back(); + worklist.pop_back(); + if (outside_to_inside.count(inst) > 0) { + continue; + } + if (inst->opcode() == HloOpcode::kBroadcast) { + add_broadcast(inst); + } else if (inst->opcode() == HloOpcode::kIota) { + add_iota(inst); + } else if (inst->opcode() == HloOpcode::kConstant) { + add_constant(inst); + } else if (inst->opcode() == HloOpcode::kReduce) { + // This is an output, for which we has special handling later. + } else { + std::vector operands_inside(inst->operand_count()); + for (int64 i = 0; i < operands_inside.size(); ++i) { + operands_inside[i] = outside_to_inside[inst->operand(i)]; + } + outside_to_inside[inst] = body->AddInstruction(inst->CloneWithNewOperands( + ShapeUtil::ChangeElementType(dus->operand(1)->shape(), + inst->shape().element_type()), + operands_inside)); + } + add_users_if_available(inst); + } + std::vector new_outputs_inside(new_operands.size()); + for (int64 i = 0; i < new_outputs_inside.size(); ++i) { + new_outputs_inside[i] = outside_to_inside[new_operands[i]]; + } + // Now create the reduce outpus inside of the loop. + for (int64 i = 0; i < reduce_outputs.size(); ++i) { + auto reduce_outside = reduce_outputs[i]; + CHECK_EQ(reduce_outside->opcode(), HloOpcode::kReduce); + int64 index_in_operand = new_operands.size() - reduce_outputs.size() + i; + auto last_iter_result = outside_to_inside[new_operands[index_in_operand]]; + auto operand0 = outside_to_inside[reduce_outside->operand(0)]; + auto operand1 = outside_to_inside[reduce_outside->operand(1)]; + TF_ASSIGN_OR_RETURN(auto reduce_shape, + ShapeInference::InferReduceShape( + {&operand0->shape(), &operand1->shape()}, + reduce_outside->dimensions(), + reduce_outside->to_apply()->ComputeProgramShape())); + *reduce_shape.mutable_layout() = reduce_outside->shape().layout(); + std::vector reduce_dus_offsets; + // If any collapsed dimension is windowed, we need to accumulate with last + // iteration's result. If such a dimension has padding, we also need to mask + // off invalid data. 
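+ // [Editor's example, hypothetical sizes]: if a collapsed dimension of
+ // logical size 500 is windowed into 4 per-iteration slices of 128, each
+ // iteration's partial reduce is accumulated into the previous result, and
+ // the iota/compare/select below replaces the 12 out-of-range elements of
+ // the last slice with the reduce's init value before reducing.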
+ bool needs_accumulate = false;
+ std::vector<int64> dims_to_mask;
+ for (int64 i = 0; i < slice_offsets.size(); ++i) {
+ if (absl::c_linear_search(reduce_outside->dimensions(), i)) {
+ if (reduce_outside->operand(0)->shape().dimensions(i) !=
+ operand0->shape().dimensions(i)) {
+ needs_accumulate = true;
+ if (unpadded_shape.dimensions(i) != padded_shape.dimensions(i)) {
+ dims_to_mask.push_back(i);
+ }
+ }
+ continue;
+ }
+ reduce_dus_offsets.push_back(slice_offsets[i]);
+ }
+ // Mask off invalid data in collapsed dimensions.
+ for (int64 dim : dims_to_mask) {
+ auto iota = body->AddInstruction(HloInstruction::CreateIota(
+ ShapeUtil::ChangeElementType(operand0->shape(), S32), dim));
+ auto add = body->AddInstruction(HloInstruction::CreateBinary(
+ iota->shape(), HloOpcode::kAdd, iota,
+ body->AddInstruction(HloInstruction::CreateBroadcast(
+ iota->shape(), slice_offsets[dim], {}))));
+ auto limit = body->AddInstruction(HloInstruction::CreateBroadcast(
+ iota->shape(),
+ body->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
+ reduce_outside->operand(0)->shape().dimensions(dim)))),
+ {}));
+ auto compare = body->AddInstruction(HloInstruction::CreateCompare(
+ ShapeUtil::ChangeElementType(iota->shape(), PRED), add, limit,
+ ComparisonDirection::kLt));
+ operand0 = body->AddInstruction(HloInstruction::CreateTernary(
+ operand0->shape(), HloOpcode::kSelect, compare, operand0,
+ body->AddInstruction(HloInstruction::CreateBroadcast(
+ operand0->shape(), operand1, {}))));
+ }
+ auto output_inside =
+ body->AddInstruction(reduce_outside->CloneWithNewOperands(
+ reduce_shape, {operand0, operand1}));
+ // Accumulate with previous results if needed.
+ if (needs_accumulate) {
+ auto input_slice =
+ body->AddInstruction(HloInstruction::CreateDynamicSlice(
+ output_inside->shape(), last_iter_result, reduce_dus_offsets,
+ output_inside->shape().dimensions()));
+ output_inside = body->AddInstruction(HloInstruction::CreateBinary(
+ output_inside->shape(),
+ reduce_outside->to_apply()->root_instruction()->opcode(),
+ output_inside, input_slice));
+ }
+ // Dynamic-update-slice if needed.
+ if (!ShapeUtil::Compatible(output_inside->shape(),
+ last_iter_result->shape())) {
+ output_inside =
+ body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+ last_iter_result->shape(), last_iter_result, output_inside,
+ reduce_dus_offsets));
+ }
+ new_outputs_inside[index_in_operand] = output_inside;
+ }
+ // Body output.
+ auto new_output_inside =
+ body->AddInstruction(HloInstruction::CreateTuple(new_outputs_inside));
+ TF_RETURN_IF_ERROR(
+ body_root->ReplaceOperandWithDifferentShape(2, new_output_inside));
+ TF_RETURN_IF_ERROR(body->RemoveInstructionAndUnusedOperands(dus));
+ // Replace uses of the reduces outside the loop.
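For a concrete picture of the masking above, consider a collapsed dimension whose unpadded size is 5, padded to 6, and processed in windows of 2 (a sketch with made-up numbers, not taken from the patch):

// iota along the masked dim (window of 2):        {0, 1}
// + broadcasted slice offset for this iteration:  {4, 5}
// compare < unpadded size (5):                    {true, false}
// select(mask, operand0, broadcast(init)):        keeps index 4 and replaces
//                                                 the padded element with the
//                                                 reduce's init value.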
+ auto new_output_gte =
+ computation->AddInstruction(HloInstruction::CreateGetTupleElement(
+ new_output_inside->shape(), loop, 2));
+ for (int64 i = 0; i < reduce_outputs.size(); ++i) {
+ int64 index_in_operand = new_operands.size() - reduce_outputs.size() + i;
+ auto new_output =
+ computation->AddInstruction(HloInstruction::CreateGetTupleElement(
+ new_outputs_inside[index_in_operand]->shape(), new_output_gte,
+ index_in_operand));
+ if (!ShapeUtil::Compatible(new_output->shape(),
+ reduce_outputs[i]->shape())) {
+ new_output = computation->AddInstruction(HloInstruction::CreateSlice(
+ reduce_outputs[i]->shape(), new_output,
+ std::vector<int64>(new_output->shape().rank(), 0),
+ reduce_outputs[i]->shape().dimensions(),
+ std::vector<int64>(new_output->shape().rank(), 1)));
+ }
+ TF_RETURN_IF_ERROR(reduce_outputs[i]->ReplaceAllUsesWith(new_output));
+ TF_RETURN_IF_ERROR(
+ computation->RemoveInstructionAndUnusedOperands(reduce_outputs[i]));
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+Status SpmdPartitioningVisitor::DoCodeMotionForWindowedDotGeneralLoops(
+ HloComputation* computation) {
+ for (auto& loop : windowed_dot_general_loops_) {
+ if (loop.windowed_in_contracting_dims || loop.windowed_in_batch_dims) {
+ // We have a dynamic-slice for the non-windowed operand in
+ // batch/contracting-dim windowed dot-general. So moving the
+ // broadcast/iota/elementwise ops into the loop could help reduce memory
+ // via fusion.
+ TF_RETURN_IF_ERROR(
+ SinkInputNodesIntoWindowedDotGeneralLoopOnContractingDimensions(
+ loop.while_loop, 1 - loop.windowed_operand));
+ }
+ if (!loop.windowed_in_contracting_dims) {
+ // We have a dynamic-update-slice for the output in
+ // batch/non-contracting-dim windowed dot-general. So moving reduce ops
+ // into the loop could help reduce memory.
+ TF_RETURN_IF_ERROR(
+ MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions(
+ loop.while_loop));
+ }
+ }
+ return Status::OK();
+}
+
+StatusOr<bool> SpmdPartitioningVisitor::DoPartition(
+ HloComputation* computation, const HloSharding& root_sharding) {
+ VLOG(2) << "Partitioning computation " << computation->name() << " for "
+ << num_replicas_ << " replicas and " << num_partitions_
+ << " partitions";
+ TF_RETURN_IF_ERROR(computation->Accept(this));
+
+ HloModule* module = computation->parent();
+ auto new_root =
+ GetPartitionedHlo(computation->root_instruction()).Reshard(root_sharding);
+ auto new_computation =
+ module->AddEmbeddedComputation(b_.Build(new_root.hlo()));
+ TF_RETURN_IF_ERROR(DoCodeMotionForWindowedDotGeneralLoops(new_computation));
+
+ // Replace the original computation with the new SPMD computation.
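The dispatch in DoCodeMotionForWindowedDotGeneralLoops consumes entries in windowed_dot_general_loops_. A hedged sketch of how such an entry might be recorded when the windowed loop is first built (the `while_op` variable and the flag values are illustrative; the actual registration site is not part of this hunk):

// Field names come from the WindowedDotGeneralLoop struct declared in the
// header below; while_op is a stand-in for the created kWhile instruction.
windowed_dot_general_loops_.push_back({/*while_loop=*/while_op,
                                       /*windowed_operand=*/1,
                                       /*windowed_in_contracting_dims=*/true,
                                       /*windowed_in_batch_dims=*/false});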
+ std::unordered_map replacement; + replacement[computation] = new_computation; + module->ReplaceComputations(replacement); + return changed_; +} + +Status SpmdPartitioningVisitor::HandlePartitionId(HloInstruction* hlo) { + return Unimplemented( + "PartitionId instruction is not supported for SPMD partitioning since " + "the meaning is ambiguous -- whether the instruction is replicated or " + "the data is replicated, and if the latter which data is replicated."); +} + +SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64 num_partitions, + int64 num_replicas) { + return { + [](SpmdBuilder* b) { + return b->AddInstruction(HloInstruction::CreatePartitionId()); + }, + [num_replicas](SpmdBuilder* b, HloInstruction* operand, + HloComputation* reduction, int64 channel_id) { + return b->AddInstruction(HloInstruction::CreateAllReduce( + operand->shape(), {operand}, reduction, + CreateReplicaGroups(num_replicas), + /*constrain_layout=*/false, channel_id, + /*use_global_device_ids=*/false)); + }, + [](SpmdBuilder* b, HloInstruction* operand, + std::vector>& src_dst_pairs, + int64 channel_id) { + return b->AddInstruction(HloInstruction::CreateCollectivePermute( + operand->shape(), operand, src_dst_pairs, channel_id)); + }, + [](SpmdBuilder* b, absl::Span operands, + const std::vector& replica_groups, int64 channel_id, + absl::optional split_dimension) { + std::vector shapes(operands.size(), operands[0]->shape()); + const Shape output_shape = (shapes.size() == 1) + ? shapes[0] + : ShapeUtil::MakeTupleShape(shapes); + return b->AddInstruction(HloInstruction::CreateAllToAll( + output_shape, operands, replica_groups, + /*constrain_layout=*/false, channel_id, split_dimension)); + }, + [num_replicas, num_partitions]( + SpmdBuilder* b, HloInstruction* operand, const Shape& ag_shape, + const std::vector>& partition_subgroups, + int64 channel_id, int64 all_gather_dimension) { + std::vector device_groups; + device_groups.reserve(partition_subgroups.size() * num_replicas); + for (int64 i = 0; i < num_replicas; ++i) { + for (const auto& pgroup : partition_subgroups) { + device_groups.emplace_back(); + for (int64 pid : pgroup) { + device_groups.back().add_replica_ids(i * num_partitions + pid); + } + } + } + return b->AddInstruction(HloInstruction::CreateAllGather( + ag_shape, operand, all_gather_dimension, device_groups, + /*constrain_layout=*/false, channel_id, + /*use_global_device_ids=*/true)); + }, + }; +} + +SpmdPartitioner::SpmdPartitioner(int64 num_partitions, int64 num_replicas, + SpmdPartitionerOptions options) + : SpmdPartitioner( + num_partitions, num_replicas, std::move(options), + GetDefaultCollectiveOpsCreator(num_partitions, num_replicas)) {} + +HloInstruction* SpmdPartitioner::AllGatherShards(SpmdBuilder* b, + HloInstruction* operand, + const HloSharding& sharding, + int64 channel_id) { + CHECK(!sharding.IsTileMaximal()); + // Add one leading dimension to gather all partitions. 
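For illustration, a minimal sketch of how the creator built by GetDefaultCollectiveOpsCreator is meant to be consumed; `b`, `partial_sum`, `add_computation`, and `next_channel_id` are assumed to exist at the call site and are not part of this change:

// Build the default creators, then use them instead of constructing the
// collective HLOs directly.
SPMDCollectiveOpsCreator creator =
    GetDefaultCollectiveOpsCreator(/*num_partitions=*/8, /*num_replicas=*/1);
HloInstruction* partition_id = creator.create_partition_id(b);
HloInstruction* global_sum = creator.create_cross_partition_all_reduce(
    b, partial_sum, add_computation, /*channel_id=*/(*next_channel_id)++);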
+ std::vector<int64> shape;
+ shape.push_back(1);
+ for (int64 dim : operand->shape().dimensions()) {
+ shape.push_back(dim);
+ }
+ auto reshape = b->AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(operand->shape().element_type(), shape), operand));
+ std::vector<std::vector<int64>> partition_subgroups(1);
+ for (int64 pid : sharding.tile_assignment()) {
+ partition_subgroups[0].push_back(pid);
+ }
+ shape[0] = sharding.tile_assignment().num_elements();
+ auto result = collective_ops_creator_.create_cross_partition_all_gather(
+ b, reshape, ShapeUtil::MakeShape(operand->shape().element_type(), shape),
+ partition_subgroups, channel_id, /*all_gather_dimension=*/0);
+ // If n > 1 dimensions are partitioned, split the leading dimension to n.
+ std::vector<int64> tiled_dims;
+ for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
+ if (sharding.tile_assignment().dim(i) > 1) {
+ tiled_dims.push_back(i);
+ }
+ }
+ if (tiled_dims.size() > 1) {
+ std::vector<int64> split_dim_shape;
+ split_dim_shape.reserve(tiled_dims.size() + operand->shape().rank());
+ for (int64 i : tiled_dims) {
+ split_dim_shape.push_back(sharding.tile_assignment().dim(i));
+ }
+ for (int64 dim : operand->shape().dimensions()) {
+ split_dim_shape.push_back(dim);
+ }
+ result = b->AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(operand->shape().element_type(), split_dim_shape),
+ result));
+ }
+ // Transpose the gathered dimensions so that they are next to their
+ // corresponding partitioned dimensions.
+ std::vector<int64> xpose_permutation(result->shape().rank());
+ int64 split_dims_added = 0;
+ for (int64 i = 0; i < xpose_permutation.size(); ++i) {
+ if (sharding.tile_assignment().dim(i - split_dims_added) == 1) {
+ xpose_permutation[i] = i + tiled_dims.size() - split_dims_added;
+ } else {
+ xpose_permutation[i] = split_dims_added;
+ xpose_permutation[i + 1] = i + tiled_dims.size() - split_dims_added;
+ split_dims_added++;
+ i++;
+ }
+ }
+ result = b->AddInstruction(HloInstruction::CreateTranspose(
+ ShapeInference::InferTransposeShape(result->shape(), xpose_permutation)
+ .ValueOrDie(),
+ result, xpose_permutation));
+ // Reshape to the desired shape.
+ auto ag_shape = operand->shape();
+ for (int64 i : tiled_dims) {
+ ag_shape.set_dimensions(
+ i, ag_shape.dimensions(i) * sharding.tile_assignment().dim(i));
+ }
+ result = b->AddInstruction(HloInstruction::CreateReshape(ag_shape, result));
+ return result;
+}
+
+StatusOr<bool> SpmdPartitioner::PartitionComputation(
+ HloComputation* computation, const HloSharding& root_sharding,
+ int64* next_channel_id, SpmdLogger* logger) {
+ auto visitor =
+ CreateVisitor(computation, num_partitions_, num_replicas_,
+ collective_ops_creator_, next_channel_id, logger, options_);
+ return visitor->DoPartition(computation, root_sharding);
+}
+
+std::unique_ptr<SpmdPartitioningVisitor> SpmdPartitioner::CreateVisitor(
+ HloComputation* computation, int64 num_partitions, int64 num_replicas,
+ const SPMDCollectiveOpsCreator& collective_ops_creator,
+ int64* next_channel_id, SpmdLogger* logger,
+ SpmdPartitionerOptions options) {
+ return absl::make_unique<SpmdPartitioningVisitor>(
+ computation, num_partitions, num_replicas, collective_ops_creator,
+ next_channel_id, logger, std::move(options), this);
+}
+
+StatusOr<bool> SpmdPartitioner::Run(HloModule* module) {
+ TF_RETURN_IF_ERROR(PreprocessSharding(module));
+
+ XLA_VLOG_LINES(1, SpmdLogger::ReportBeforePartition(
+ *module, options_.report_instruction_count));
+
+ // Add the parameters' and output's shardings to the module.
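To make the reshape/transpose sequence in AllGatherShards above concrete, here is a worked shape trace (illustrative only) for an f32[4,6] operand tiled by a [2,2] assignment over 4 partitions:

// per-partition shard:                      f32[2,3]
// reshape with a leading 1:                 f32[1,2,3]
// all-gather on dim 0 over 4 partitions:    f32[4,2,3]
// split the leading dim into 2 tiled dims:  f32[2,2,2,3]
// transpose with permutation {0,2,1,3}:     f32[2,2,2,3]  (tile0, shard0, tile1, shard1)
// final reshape back to the full shape:     f32[4,6]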
+ std::vector<HloSharding> entry_params_shardings;
+ for (int64 i = 0; i < module->entry_computation()->num_parameters(); ++i) {
+ auto param = module->entry_computation()->parameter_instruction(i);
+ CHECK(param->has_sharding()) << "Missing sharding in entry parameter " << i;
+ entry_params_shardings.push_back(param->sharding());
+ }
+ module->set_spmd_parameters_shardings(entry_params_shardings);
+ auto entry_root = module->entry_computation()->root_instruction();
+ CHECK(entry_root->has_sharding()) << "Missing sharding in entry root.";
+ module->set_spmd_output_sharding(entry_root->sharding());
+
+ FlattenCallGraph flatten;
+ TF_ASSIGN_OR_RETURN(auto changed, flatten.Run(module));
+
+ SpmdLogger logger(options_.report_instruction_count);
+ auto program_shape = module->entry_computation()->ComputeProgramShape();
+ int64 next_channel_id = hlo_query::NextChannelId(*module);
+ TF_ASSIGN_OR_RETURN(
+ bool partition_changed,
+ PartitionComputation(
+ module->entry_computation(),
+ module->entry_computation()->root_instruction()->sharding(),
+ &next_channel_id, &logger));
+ changed |= partition_changed;
+
+ // For the entry computation, make sure that the root instruction and the
+ // parameters preserve their signatures.
+ auto new_program_shape = module->entry_computation()->ComputeProgramShape();
+ if (!options_.allow_module_signature_change) {
+ TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
+ program_shape.result(), new_program_shape.result()))
+ << "Result shape changed for the entry computation";
+ TF_RET_CHECK(program_shape.parameters_size() ==
+ new_program_shape.parameters_size())
+ << "Parameter count changed for the entry computation";
+ for (int64 i = 0; i < program_shape.parameters_size(); ++i) {
+ TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
+ program_shape.parameters(i), new_program_shape.parameters(i)))
+ << "Parameter shape changed for the entry computation";
+ }
+ } else {
+ const auto& old_entry_layout = module->entry_computation_layout();
+ // Shapes can change but the layout should still remain the same.
+ for (int64 i = 0; i < new_program_shape.parameters_size(); ++i) {
+ TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
+ old_entry_layout.parameter_shape(i),
+ new_program_shape.mutable_parameters(i)));
+ }
+ TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
+ old_entry_layout.result_shape(), new_program_shape.mutable_result()));
+
+ HloModuleConfig config = module->config();
+ *config.mutable_entry_computation_layout() =
+ ComputationLayout(new_program_shape, /*ignore_layouts=*/false);
+ module->set_config(config);
+ }
+
+ XLA_VLOG_LINES(1, SpmdLogger::ReportAfterPartition(
+ *module, options_.report_instruction_count));
+ XLA_VLOG_LINES(1, logger.MakeReport());
+
+ if (changed) {
+ HloPassPipeline pass("spmd-cleanup");
+ pass.AddPass<TupleSimplifier>();
+ pass.AddPass<HloDCE>();
+ pass.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
+ pass.AddPass<FlattenCallGraph>();
+ TF_RETURN_IF_ERROR(pass.Run(module).status());
+ }
+
+ TF_RETURN_IF_ERROR(ClearShardingAttributes(module));
+ return changed;
+}
+
+Status SpmdPartitioner::PreprocessSharding(HloModule* module) {
+ for (HloComputation* computation : module->computations()) {
+ for (HloInstruction* hlo : computation->instructions()) {
+ if (hlo->HasSideEffectNoRecurse() && hlo->opcode() != HloOpcode::kRng) {
+ TF_RET_CHECK(hlo->has_sharding())
+ << "Side-effect HLO must have sharding: " << hlo->ToString();
+ TF_RET_CHECK(!HasReplicatedSharding(hlo->sharding()) ||
+ hlo->opcode() == HloOpcode::kInfeed)
+ << "Non-infeed side-effect HLO cannot have a replicated sharding:"
+ << hlo->ToString();
+ }
+
+ // For unassigned HLOs, annotate with replicated sharding.
+ //
+ // Among side-effecting ops, only Rng is allowed to omit the annotation.
+ // In that case, we currently force it to run on core 0, since we don't
+ // support partitioning or replicating the Rng op (the values depend on
+ // the seed provided to each device).
+ //
+ // TODO(hyouklee): Should we also convert single-device shardings (without
+ // side-effects) into replicated?
+ if (!hlo->has_sharding()) {
+ if (hlo->opcode() == HloOpcode::kRng) {
+ hlo->set_sharding(HloSharding::AssignDevice(0));
+ } else {
+ hlo->set_sharding(
+ HloSharding::Single(hlo->shape(), HloSharding::Replicate()));
+ }
+ } else if (!hlo->sharding().IsTileMaximal()) {
+ std::vector<int64> available(num_partitions_);
+ std::iota(available.begin(), available.end(), 0);
+ TF_RET_CHECK(num_partitions_ == hlo_sharding_util::DevicesForSharding(
+ hlo->sharding(), available)
+ .size())
+ << "num_partitions:" << num_partitions_ << "\n"
+ << "SPMD partitioner only supports tile sharding that includes all "
+ "partitions. If you didn't add this sharding annotation in the "
+ "model, please file a bug to the XLA team.\n"
+ << hlo->ToString();
+ }
+ }
+ }
+
+ // Entry computation's parameter and root sharding must be either all
+ // replicated or all on a single device.
+ if (!options_.allow_module_signature_change) { + const HloComputation* entry = module->entry_computation(); + TF_RET_CHECK(entry->root_instruction()->has_sharding()); + const HloSharding& root_sharding = entry->root_instruction()->sharding(); + TF_RET_CHECK(root_sharding.IsReplicated() || + root_sharding.UniqueDevice().has_value()) + << "Unsupported entry root sharding: " << root_sharding.ToString(); + + for (const HloInstruction* param : entry->parameter_instructions()) { + TF_RET_CHECK(param->has_sharding()); + TF_RET_CHECK(param->sharding().IsReplicated() || + param->sharding().UniqueDevice().has_value()) + << "Unsupported entry parameter sharding:" + << param->sharding().ToString(); + } + } + + return Status::OK(); +} + +} // namespace spmd +} // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h new file mode 100644 index 00000000000..52e4c9021d8 --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h @@ -0,0 +1,465 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_ + +#include +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" + +namespace xla { +namespace spmd { + +struct SpmdPartitionerOptions { + // Always exchange halo on LHS for all convolutions. If false, backprop filter + // convolution exchanges halo on RHS. + bool conv_halo_exchange_always_on_lhs = true; + + // The number of instructions to be reported for the highest memory profile + // instructions. + int64 report_instruction_count = 5; + + // The minimum size in MiB of an einsum operand to be considered using + // windowed implementation in an HLO loop. + int64 threshold_for_windowed_einsum_mib = 256; + + // Whether the entry computations' signature could change after partitioning. + bool allow_module_signature_change = false; +}; + +// Class to wrap the computation builder to capture information during SPMD +// transformation. 
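A minimal sketch of constructing the pass with these options (the partition and replica counts are arbitrary, and the SpmdPartitioner constructor used here is the one declared further down in this header):

// Configure the partitioner; defaults shown explicitly for clarity.
SpmdPartitionerOptions options;
options.conv_halo_exchange_always_on_lhs = true;
options.threshold_for_windowed_einsum_mib = 256;
options.allow_module_signature_change = false;
SpmdPartitioner partitioner(/*num_partitions=*/8, /*num_replicas=*/1, options);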
+class SpmdBuilder : public HloComputation::Builder { + public: + SpmdBuilder(const std::string& name, HloInstruction* hlo) + : HloComputation::Builder(name) { + visiting_hlo_ = hlo; + } + HloInstruction* AddInstruction(std::unique_ptr instruction); + + const std::vector& derived_instructions( + HloInstruction* hlo) { + return instructions_.at(hlo); + } + + void set_visiting_hlo(HloInstruction* hlo) { visiting_hlo_ = hlo; } + + HloInstruction* visiting_hlo() const { return visiting_hlo_; } + + private: + // Currently visiting instruction. + HloInstruction* visiting_hlo_; + + // Map from the currently visiting (old) instruction to new instructions + // created during SPMD partitioning. + HloInstructionMap> instructions_; +}; + +// A set of functions that create the cross-partition collective ops. +struct SPMDCollectiveOpsCreator { + // Function used to create a partition ID HLO. + std::function create_partition_id; + + // Function used to create a cross-partition all-reduce HLO. + std::function + create_cross_partition_all_reduce; + + // Function used to create a cross-partition collective-permute HLO. + std::function>& src_dst_pairs, + int64 next_channel_id)> + create_cross_partition_collective_permute; + + // Function used to create a cross-partition all-to-all HLO. + std::function operands, + const std::vector& replica_groups, int64 channel_id, + absl::optional split_dimension)> + create_cross_partition_all_to_all; + + // Function used to create a cross-partition all-gather HLO. This is optional: + // if it is nullptr, the partitioner will use all-reduce instead. + std::function>& partition_subgroups, + int64 channel_id, int64 all_gather_dimension)> + create_cross_partition_all_gather; +}; + +// Create a default SPMDCollectiveOpsCreator. +SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64 num_partitions, + int64 num_replicas); + +// Logger to report memory usage during SPMD partitioning. +class SpmdLogger { + public: + explicit SpmdLogger(int64 report_instruction_count) + : report_instruction_count_(report_instruction_count) {} + static std::string ReportBeforePartition(const HloModule& module, + int64 report_instruction_count); + static std::string ReportAfterPartition(const HloModule& module, + int64 report_instruction_count); + + // Registers the logging for the groups of instructions created to transform + // the given hlo. + void RegisterLogEntry(HloInstruction* hlo, + const std::vector& group); + + std::string MakeReport(); + + private: + template + static std::string ReportMemoryUsage(const HloModule& module, const F& filter, + int64 report_instruction_count); + + // A vector of logging messages (one for each original HLO instruction), where + // the first integer of the pair represents the size of the HBM used. 
+ std::vector> entries_; + + int64 report_instruction_count_; +}; + +class SpmdPartitioningVisitor; + +class SpmdPartitioner : public HloModulePass { + public: + SpmdPartitioner(int64 num_partitions, int64 num_replicas, + SpmdPartitionerOptions options); + SpmdPartitioner(int64 num_partitions, int64 num_replicas, + SpmdPartitionerOptions options, + SPMDCollectiveOpsCreator collective_ops_creator) + : num_partitions_(num_partitions), + num_replicas_(num_replicas), + options_(std::move(options)), + collective_ops_creator_(std::move(collective_ops_creator)) {} + absl::string_view name() const override { return "spmd-partitioning"; } + StatusOr Run(HloModule* module) override; + + // Transforms the given computation with SPMD instructions, replacing it with + // a new computation. + StatusOr PartitionComputation(HloComputation* computation, + const HloSharding& root_sharding, + int64* next_channel_id, + SpmdLogger* logger); + + // Creates all-gather based on HloSharding. Can be overridden to customize. + // The default uses a single all-gather even if there are multiple sharded + // dimensions, and adds potential reshapes and transposes to achieve that. + // If it returns false, the partitioner will fall back to all-reduce. + virtual HloInstruction* AllGatherShards(SpmdBuilder* b, + HloInstruction* operand, + const HloSharding& sharding, + int64 channel_id); + + protected: + virtual std::unique_ptr CreateVisitor( + HloComputation* computation, int64 num_partitions, int64 num_replicas, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdLogger* logger, + SpmdPartitionerOptions options); + + // Verify that the sharding of instructions in the module are valid, and also + // fill in missing sharding information. + Status PreprocessSharding(HloModule* module); + + const int64 num_partitions_; + const int64 num_replicas_; + + SpmdPartitionerOptions options_; + SPMDCollectiveOpsCreator collective_ops_creator_; +}; + +// Class describes partition state of the data represented by an HLO created +// during SPMD partitioning pass. +// +// Data on some devices may include padding region, if the base (full) shape +// could not be evenly partitioned. +class PartitionedHlo { + public: + // Return value for ReshardAsWindowedInput which describes the resharded HLO, + // the window for the user on the shard, and if necessary, the dynamic slice + // offsets to be applied to the output of the op being sharded. + struct WindowedInputShardReturnValue { + HloInstruction* sharded_input; + Window shard_window; + absl::optional> dynamic_slice_index_on_output; + }; + // A cache for resharding each partitioned HLO. + struct ReshardCache { + struct PerHloCache { + std::vector> reshard_cache; + std::vector< + std::tuple> + window_reshard_cache; + }; + std::unordered_map per_hlo_cache; + }; + struct PartitioningState { + SpmdBuilder* b; + HloModule* module; + int64 num_replicas; + HloInstruction* partition_id; + SPMDCollectiveOpsCreator collective_ops_creator; + int64* next_channel_id; + ReshardCache* reshard_cache; + SpmdPartitioner* partitioner; + }; + PartitionedHlo(HloInstruction* hlo, Shape base_shape, PartitioningState state) + : hlo_(hlo), base_shape_(base_shape), state_(std::move(state)) { + CHECK(hlo->has_sharding()) + << "PartitionedHlo is missing sharding:" << hlo->ToString(); + // If the tuple shape instruction does not have a tuple sharding, reassign + // to use the tuple sharding. Reshard() implementation assumes this. 
+ if (hlo_->shape().IsTuple() && !hlo_->sharding().IsTuple()) { + hlo_->set_sharding( + hlo_->sharding().GetTupleSharding(hlo_->shape()).ValueOrDie()); + } + } + + // Reshards the current SPMD instruction to a new sharding. Could only modify + // the reshard cache. + PartitionedHlo Reshard(const HloSharding& target); + + // Pads the garbage area of the output with the provided value. Normally, + // unevenly partitioned dimensions are padded on the right, but this function + // allows specifying left-padded dimensions, which can be used during the + // handling of kReverse, etc. + PartitionedHlo PadWithValue( + HloInstruction* pad_value, + absl::Span left_padded_dims = {}) const; + + // Returns the SPMD instruction. + HloInstruction* hlo() const { return hlo_; } + + // Returns the sharding of the SPMD instruction. + const HloSharding& sharding() const { return hlo_->sharding(); } + + // Original full shape of the data. + const Shape& base_shape() const { return base_shape_; } + + int64 NewChannel() const { return (*state_.next_channel_id)++; } + + // Reshards the HLO to a usable partitioned input for a windowed user. Could + // only modify the reshard cache. + absl::optional ReshardAsWindowedInput( + const Window& window, const HloSharding& target, + HloInstruction* pad_value, bool mask_invalid_region = true); + + const PartitioningState& state() const { return state_; } + + private: + // Same as Reshard except that it does not explicitly modify the reshard + // cache, although it would indirectly modify by calling Replicate(). + PartitionedHlo ReshardNoCache(const HloSharding& target); + + // Helper function to replicate the data on all devices. Could only modify + // the reshard cache. + PartitionedHlo Replicate(); + + // Helper function to broadcast data from a single device to all devices. + PartitionedHlo Broadcast() const; + + // Helper function to reshard the tensor using AllToAll (instead of the + // default of Replicate followed by Slice). + PartitionedHlo ReshardWithAllToAll(const HloSharding& target) const; + + // Helper function to reshard the tensor using CollectivePermute. + PartitionedHlo ReshardWithCollectivePermute(const HloSharding& target) const; + + // SPMD instruction. + HloInstruction* hlo_; + + // The original shape of the data before SPMD transformation is applied. + Shape base_shape_; + + PartitioningState state_; +}; + +struct DotGeneralDimsMapping { + // The dimension numbers for the operands and output corresponding to a + // logical dimension (e.g., batch, contracting, non-contracting). If an + // operand or the output doesn't have the logical dimension, it is set to + // -1. 
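For illustration, a hedged sketch of the typical lifecycle of a PartitionedHlo inside a visitor handler; `hlo`, `base_shape`, and `MakePartitioningState()` are assumed to come from the visitor and are not defined at this point in the header:

// Wrap the per-partition instruction together with its full (base) shape and
// the partitioning state, then reshard as required by the consumer.
PartitionedHlo partitioned(hlo, base_shape, MakePartitioningState());
HloInstruction* replicated =
    partitioned.Reshard(HloSharding::Replicate()).hlo();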
+ struct DimsMapping { + int64 lhs; + int64 rhs; + int64 output; + }; + std::vector batch_dims; + std::vector contracting_dims; + std::vector lhs_non_contracting_dims; + std::vector rhs_non_contracting_dims; +}; + +class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault { + public: + SpmdPartitioningVisitor( + HloComputation* computation, int64 num_partitions, int64 num_replicas, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdLogger* logger, + SpmdPartitionerOptions options, SpmdPartitioner* partitioner); + + Status DefaultAction(HloInstruction* hlo) override; + Status HandleAllReduce(HloInstruction* hlo) override; + Status HandleBroadcast(HloInstruction* hlo) override; + Status HandleConstant(HloInstruction* hlo) override; + Status HandleCustomCall(HloInstruction* hlo) override; + Status HandleDot(HloInstruction* hlo) override; + Status HandleDynamicSlice(HloInstruction* hlo) override; + Status HandleDynamicUpdateSlice(HloInstruction* hlo) override; + Status HandleGather(HloInstruction* hlo) override; + Status HandleGetTupleElement(HloInstruction* hlo) override; + Status HandleInfeed(HloInstruction* hlo) override; + Status HandleOutfeed(HloInstruction* hlo) override; + Status HandlePad(HloInstruction* hlo) override; + Status HandleParameter(HloInstruction* hlo) override; + Status HandleReduce(HloInstruction* hlo) override; + Status HandleReverse(HloInstruction* hlo) override; + Status HandleWhile(HloInstruction* hlo) override; + Status HandleConditional(HloInstruction* hlo) override; + Status HandleReduceWindow(HloInstruction* hlo) override; + Status HandleSelectAndScatter(HloInstruction* hlo) override; + Status HandleTuple(HloInstruction* hlo) override; + Status HandleRng(HloInstruction* hlo) override; + Status HandleConvolution(HloInstruction* hlo) override; + Status HandleConcatenate(HloInstruction* hlo) override; + Status HandleScatter(HloInstruction* hlo) override; + Status HandleSlice(HloInstruction* hlo) override; + Status HandleSort(HloInstruction* hlo) override; + Status HandleTranspose(HloInstruction* hlo) override; + Status HandleReshape(HloInstruction* hlo) override; + Status HandleIota(HloInstruction* hlo) override; + Status HandlePartitionId(HloInstruction* hlo) override; + + // Handles convolution where both LHS and RHS operands are tiled. + Status HandleConvolutionTiledLhsAndRhs(HloInstruction* hlo); + + // Implementation of dot partitioning given DotGeneralDimsMapping. + Status HandleDotHelper( + HloInstruction* hlo, const DotGeneralDimsMapping& dims_mapping, + const std::function( + HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot); + + // Common handle for elementwise HLOs. + Status HandleElementwise(HloInstruction* hlo); + + // Common handle for HLOs that runs on a single device. + Status HandleSingleDevice(const HloInstruction* hlo); + + // Returns the PartitionedHlo that corresponds to the original hlo. + PartitionedHlo& GetPartitionedHlo(const HloInstruction* hlo) { + CHECK_EQ(partitioned_instructions_.count(hlo), 1); + return partitioned_instructions_.find(hlo)->second; + } + + // Sets the PartitionedHlo for the original hlo. + void SetPartitionedHlo(const HloInstruction* hlo, + const PartitionedHlo& partitioned_hlo) { + CHECK_EQ(partitioned_instructions_.count(hlo), 0); + partitioned_instructions_.emplace(hlo, partitioned_hlo); + changed_ = true; + } + + // Convenient wrapper that creates PartitionedHlo from the result of the func + // and maps it to the given original hlo. 
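As an illustration of how these handlers typically combine GetPartitionedHlo, Reshard, and the SetPartitionedHlo wrapper described in the comment above, here is a sketch of the elementwise case; MakePartitionedShape is assumed to be a helper from the partitioner's utility header, and this is a sketch rather than necessarily the exact implementation:

Status SpmdPartitioningVisitor::HandleElementwise(HloInstruction* hlo) {
  std::vector<HloInstruction*> new_operands;
  for (HloInstruction* operand : hlo->operands()) {
    // Reshard every operand to the output sharding, then use its
    // per-partition instruction.
    new_operands.push_back(
        GetPartitionedHlo(operand).Reshard(hlo->sharding()).hlo());
  }
  SetPartitionedHlo(hlo, [&] {
    // Clone the op at the per-partition shape with the resharded operands.
    return b_.AddInstruction(hlo->CloneWithNewOperands(
        MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands));
  });
  return Status::OK();
}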
+ void SetPartitionedHlo(const HloInstruction* hlo, + const std::function& func) { + HloInstruction* new_hlo = func(); + new_hlo->set_sharding(hlo->sharding()); + new_hlo->set_metadata(hlo->metadata()); + SetPartitionedHlo( + hlo, PartitionedHlo(new_hlo, hlo->shape(), MakePartitioningState())); + changed_ = true; + } + + int64 NewChannel() { return (*next_channel_id_)++; } + + PartitionedHlo::PartitioningState MakePartitioningState() { + PartitionedHlo::PartitioningState state; + state.b = &b_; + state.module = module_; + state.num_replicas = num_replicas_; + state.partition_id = partition_id_; + state.collective_ops_creator = collective_ops_creator_; + state.next_channel_id = next_channel_id_; + state.reshard_cache = &reshard_cache_; + state.partitioner = partitioner_; + return state; + } + + SpmdBuilder* builder() { return &b_; } + + StatusOr DoPartition(HloComputation* computation, + const HloSharding& root_sharding); + + private: + Status Preprocess(HloInstruction* hlo) override; + Status Postprocess(HloInstruction* hlo) override; + + // Performs code motion for windowed dot-general loops in + // windowed_dot_general_loops_. Invoked after the visitor finishes traversing + // the graph. + Status DoCodeMotionForWindowedDotGeneralLoops(HloComputation* computation); + + bool changed_; + HloModule* module_; + int64 num_partitions_; + int64 num_replicas_; + + SPMDCollectiveOpsCreator collective_ops_creator_; + + // Tracks the next channel id to use for cross-partition all-reduce. + int64* next_channel_id_; + SpmdBuilder b_; + + HloInstruction* partition_id_; + + PartitionedHlo::ReshardCache reshard_cache_; + + // Mapping from the instruction in the original computation to the new SPMD + // partitioned instruction. + ConstHloInstructionMap partitioned_instructions_; + + // Information about a loop created for windowed dot-general. Used when + // DoCodeMotionForWindowedDotGeneralLoops() executes after the visitor + // finishes traversing the graph. + struct WindowedDotGeneralLoop { + HloInstruction* while_loop; + int64 windowed_operand; + bool windowed_in_contracting_dims; + bool windowed_in_batch_dims; + }; + std::vector windowed_dot_general_loops_; + + HloInstruction* visiting_hlo_; + SpmdLogger* logger_; + const SpmdPartitionerOptions options_; + SpmdPartitioner* partitioner_; +}; + +} // namespace spmd +} // namespace xla +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_ diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc new file mode 100644 index 00000000000..1f0b1d06c1f --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc @@ -0,0 +1,3771 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace spmd { +namespace { + +using ::testing::_; +using ::testing::AllOf; +namespace op = xla::testing::opcode_matchers; + +class SpmdPartitioningTest : public HloTestBase { + public: + StatusOr> PartitionComputation( + const char* hlo_module, int64 num_devices, + bool conv_halo_exchange_always_on_lhs = true) { + // Some tests (BackpropFilter convs) set this flag false to test two + // different paths of the implementation. + SpmdPartitionerOptions options; + options.conv_halo_exchange_always_on_lhs = conv_halo_exchange_always_on_lhs; + options.allow_module_signature_change = true; + auto collective_ops_creator = + GetDefaultCollectiveOpsCreator(num_devices, /*num_replicas=*/1); + // Do not use all-gather for pattern-matching purpose, as the partitioner + // might create reshape/transposes around it. + collective_ops_creator.create_cross_partition_all_gather = nullptr; + + TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule( + hlo_module, GetModuleConfigForTest())); + HloPassPipeline pass("spmd-partitioning"); + pass.AddPass(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); + pass.AddPass(num_devices, /*num_replicas=*/1, options, + collective_ops_creator); + pass.AddPass(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); + TF_RETURN_IF_ERROR(pass.Run(module.get()).status()); + return StatusOr>(std::move(module)); + } +}; + +TEST_F(SpmdPartitioningTest, InvalidSharding) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + token0 = token[] after-all(), sharding={maximal device=0} + infeed = (f32[8,2]{1,0}, token[]) infeed(token0), + sharding={{devices=[2,1]0,1}, {maximal device=0}} + ROOT infeed.data = f32[8,2]{1,0} get-tuple-element(infeed), index=0, + sharding={maximal device=0} +})"; + auto module_status = PartitionComputation(hlo_string, /*num_devices=*/4); + EXPECT_FALSE(module_status.status().ok()); + EXPECT_THAT(module_status.status().ToString(), + ::testing::HasSubstr( + "only supports tile sharding that includes all partitions")); +} + +TEST_F(SpmdPartitioningTest, SingleDeviceToReplicated) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}), + sharding={maximal device=0} + ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Copy(op::AllReduce( + op::Select(op::Broadcast(op::Compare()), + op::Constant(), op::Broadcast()))), + op::Shape("s32[2,3]"))); +} + +TEST_F(SpmdPartitioningTest, SingleDeviceToSingleDevice) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}), + sharding={maximal device=0} + ROOT %copy = s32[2,3]{1,0} copy(%constant), 
sharding={maximal device=1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + HloInstruction* root = module->entry_computation()->root_instruction(); + VLOG(1) << module->ToString(); + EXPECT_THAT(root, op::Copy(AllOf(op::Copy(op::AllReduce(op::Select( + op::Broadcast(op::Compare()), + op::Constant(), op::Broadcast()))), + op::Shape("s32[2,3]")))); +} + +TEST_F(SpmdPartitioningTest, SingleDeviceToTiled) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}), + sharding={maximal device=0} + ROOT %copy = s32[2,3]{1,0} copy(%constant), + sharding={devices=[2,1]1,0} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf( + op::Copy(op::DynamicSlice( + op::AllReduce(op::Select( + op::Broadcast(op::Compare(op::PartitionId(), op::Constant())), + op::Constant(), op::Broadcast())), + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId(), + op::Constant())), + op::Constant())), + op::Shape("s32[1,3]"))); +} + +TEST_F(SpmdPartitioningTest, TiledToReplicated) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}), + sharding={devices=[2,1]0,1} + ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + op::Copy(op::AllReduce(AllOf( + op::DynamicUpdateSlice( + op::Broadcast(), AllOf(op::Constant(), op::Shape("s32[1,3]")), + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId(), + op::Constant())), + op::Constant()), + op::Shape("s32[2,3]"))))); +} + +TEST_F(SpmdPartitioningTest, TiledToSingleDevice) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}), + sharding={devices=[2,1]0,1} + ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={maximal device=0} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + op::Copy(op::Copy(op::AllReduce(AllOf( + op::DynamicUpdateSlice( + op::Broadcast(), AllOf(op::Constant(), op::Shape("s32[1,3]")), + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId(), + op::Constant())), + op::Constant()), + op::Shape("s32[2,3]")))))); +} + +TEST_F(SpmdPartitioningTest, TiledToTiledEven) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param= s32[8,2]{1,0} parameter(0), sharding={devices=[2,1]0,1} + ROOT %copy = s32[8,2]{1,0} copy(%param), sharding={devices=[1,2]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::Copy(op::Reshape(op::Transpose(op::AllToAll(AllOf( + op::Reshape(op::Parameter()), op::Shape("s32[4,2,1]")))))), + op::Shape("s32[8,1]"))); +} + +TEST_F(SpmdPartitioningTest, TiledToTiledUneven) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param= f32[7,31,128]{2,1,0} parameter(0), 
sharding={devices=[1,2,1]0,1} + ROOT %copy = f32[7,31,128]{2,1,0} copy(%param), sharding={devices=[2,1,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::Copy(op::Slice(op::Reshape(AllOf(op::Transpose(op::AllToAll( + op::Reshape(AllOf(op::Pad(), op::Shape("f32[8,16,128]"))))))))))); +} + +TEST_F(SpmdPartitioningTest, GetTupleElementSwapDevice) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param.0 = (f32[2,3]{1,0}, u32[]) parameter(0), + sharding={{maximal device=1}, {maximal device=1}} + %gte.0 = f32[2,3]{1,0} get-tuple-element(%param.0), index=0, + sharding={maximal device=0} + %gte.1 = u32[] get-tuple-element(%param.0), index=1, + sharding={maximal device=0} + ROOT %tuple = (f32[2,3]{1,0}, u32[]) tuple(%gte.0, %gte.1), + sharding={{maximal device=0},{maximal device=0}} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + ASSERT_THAT(root, op::Tuple()); + + EXPECT_THAT(root->operand(0), + op::Copy(op::AllReduce(op::Select( + op::Broadcast(op::Compare(op::PartitionId(), op::Constant())), + op::GetTupleElement(op::Parameter()), op::Broadcast())))); + EXPECT_THAT(root->operand(1), + op::Copy(op::AllReduce(op::Select( + op::Broadcast(op::Compare(op::PartitionId(), op::Constant())), + op::GetTupleElement(op::Parameter()), op::Broadcast())))); +} + +TEST_F(SpmdPartitioningTest, GetTupleElementTiled) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + param.0 = (f32[2,3]{1,0}, u32[2,3]{1,0}) parameter(0), + sharding={{replicated}, {replicated}} + gte.0 = f32[2,3]{1,0} get-tuple-element(param.0), index=0, + sharding={devices=[2,1]0,1} + gte.1 = u32[2,3]{1,0} get-tuple-element(param.0), index=1, + sharding={devices=[2,1]0,1} + ROOT %tuple = (f32[2,3]{1,0}, u32[2,3]{1,0}) tuple(gte.0, gte.1), + sharding={{devices=[2,1]0,1},{devices=[2,1]0,1}} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + ASSERT_THAT(root, op::Tuple()); + + auto offset = op::Reshape( + op::DynamicSlice(op::Constant(), op::PartitionId(), op::Constant())); + + EXPECT_THAT(root->operand(0), + op::DynamicSlice(op::GetTupleElement(op::Parameter()), offset, + op::Constant())); + EXPECT_THAT(root->operand(1), + op::DynamicSlice(op::GetTupleElement(op::Parameter()), offset, + op::Constant())); +} + +TEST_F(SpmdPartitioningTest, TiledInfeed) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + token0 = token[] after-all(), sharding={maximal device=0} + infeed = (f32[8,2]{1,0}, token[]) infeed(token0), + sharding={{devices=[2,1]0,1}, {maximal device=0}} + ROOT infeed.data = f32[8,2]{1,0} get-tuple-element(infeed), index=0, + sharding={maximal device=0} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, op::Copy(op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(), + op::GetTupleElement( + AllOf(op::Infeed(), op::Shape("(f32[4,2]{1,0}, token[])"))), + op::Reshape(op::DynamicSlice(op::Constant(), 
op::PartitionId(), + op::Constant())), + op::Constant())))); +} + +TEST_F(SpmdPartitioningTest, UnevenTiledInfeed) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + token0 = token[] after-all(), sharding={maximal device=0} + infeed = (f32[9,2]{1,0}, token[]) infeed(token0), + sharding={{devices=[2,1]0,1}, {maximal device=0}} + ROOT infeed.data = f32[9,2]{1,0} get-tuple-element(infeed), index=0, + sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, AllOf(op::Shape("f32[5,2]"), op::GetTupleElement(op::Conditional( + op::Convert(op::PartitionId()), + op::AfterAll(), op::AfterAll())))); + EXPECT_THAT( + root->operand(0)->called_computations()[0]->root_instruction(), + AllOf(op::Shape("(f32[5,2], token[])"), op::Infeed(op::Parameter()))); + auto second_infeed = + AllOf(op::Shape("(f32[4,2], token[])"), op::Infeed(op::Parameter())); + EXPECT_THAT(root->operand(0)->called_computations()[1]->root_instruction(), + AllOf(op::Shape("(f32[5,2], token[])"), + op::Tuple(op::Pad(op::GetTupleElement(second_infeed), + op::Constant()), + op::GetTupleElement(second_infeed)))); +} + +TEST_F(SpmdPartitioningTest, UnevenTiledTupleInfeed) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + token0 = token[] after-all(), sharding={maximal device=0} + infeed = ((f32[9,2]{1,0}, f32[2]{0}), token[]) infeed(token0), + sharding={{devices=[2,1]0,1}, {replicated}, {maximal device=0}} + ROOT infeed.data = (f32[9,2]{1,0}, f32[2]{0}) get-tuple-element(infeed), + index=0, sharding={{devices=[2,1]0,1}, {replicated}} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("(f32[5,2], f32[2])"), + op::GetTupleElement(op::Conditional( + op::Convert(op::PartitionId()), op::AfterAll(), + op::AfterAll())))); + EXPECT_THAT(root->operand(0)->called_computations()[0]->root_instruction(), + AllOf(op::Shape("((f32[5,2], f32[2]), token[])"), + op::Infeed(op::Parameter()))); + auto second_infeed = AllOf(op::Shape("((f32[4,2], f32[2]), token[])"), + op::Infeed(op::Parameter())); + EXPECT_THAT( + root->operand(0)->called_computations()[1]->root_instruction(), + AllOf(op::Shape("((f32[5,2], f32[2]), token[])"), + op::Tuple(op::Tuple(op::Pad(op::GetTupleElement( + op::GetTupleElement(second_infeed)), + op::Constant()), + op::GetTupleElement( + op::GetTupleElement(second_infeed))), + op::GetTupleElement(second_infeed)))); +} + +TEST_F(SpmdPartitioningTest, TiledToReplicatedReduce) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + constant = f32[3,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1}}), + sharding={devices=[2,1]0,1} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT reduce = f32[] reduce(constant, constant.1), dimensions={0,1}, + to_apply=sum, sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + op::AllReduce(op::Reduce( + op::Select( + op::Compare(op::Add(op::Iota(), 
op::Broadcast(op::Reshape())), + op::Broadcast(op::Constant())), + AllOf(op::Shape("f32[2,3]{1,0}"), + op::DynamicSlice(op::Pad(op::Constant(), op::Constant()), + op::Reshape(), op::Constant())), + op::Broadcast(op::Constant())), + op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, TiledElementwise) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + constant = f32[3,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1}}), + sharding={devices=[2,1]0,1} + constant.1 = f32[3,3]{1,0} constant({{2,2,2},{2,2,2},{2,2,2}}), + sharding={replicated} + multiply = f32[3,3]{1,0} multiply(constant, constant.1), + sharding={devices=[2,1]0,1} + ROOT add = f32[3,3]{1,0} add(multiply, constant.1), + sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf( + op::Shape("f32[2,3]{1,0}"), + op::Add(op::Multiply( + op::DynamicSlice(op::Pad(op::Constant(), op::Constant()), + op::Reshape(), op::Constant()), + op::DynamicSlice(op::Pad(op::Constant(), op::Constant()), + op::Reshape(), op::Constant())), + op::DynamicSlice(op::Pad(op::Constant(), op::Constant()), + op::Reshape(), op::Constant())))); +} + +TEST_F(SpmdPartitioningTest, TiledAllReduce) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + parameter = f32[3,3]{1,0} parameter(0), sharding={devices=[2,1]0,1} + ROOT all-reduce = f32[3,3]{1,0} all-reduce(parameter), to_apply=sum, + replica_groups={}, sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, AllOf(op::Shape("f32[2,3]{1,0}"), op::AllReduce(op::Parameter(0)))); +} + +TEST_F(SpmdPartitioningTest, BroadcastOnlyNewDimsSharded) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + constant = f32[4,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1},{1,1,1}}), + sharding={replicated} + ROOT broadcast = f32[3,4,3]{2,1,0} broadcast(constant), dimensions={1,2}, + sharding={devices=[2,1,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[2,4,3]{2,1,0}"), + op::Broadcast(op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, BroadcastOnlyOldDimsSharded) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + constant = f32[4,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1},{1,1,1}}), + sharding={replicated} + ROOT broadcast = f32[4,4,3]{2,1,0} broadcast(constant), dimensions={1,2}, + sharding={devices=[1,2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[4,2,3]{2,1,0}"), + op::Broadcast(op::DynamicSlice( + op::Constant(), op::Reshape(), op::Constant())))); +} + +TEST_F(SpmdPartitioningTest, BroadcastBothOldAndNewDimsSharded) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + constant = f32[4,3]{1,0} 
constant({{1,1,1},{1,1,1},{1,1,1},{1,1,1}}), + sharding={replicated} + ROOT broadcast = f32[4,4,3]{2,1,0} broadcast(constant), dimensions={1,2}, + sharding={devices=[2,2,1]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::Shape("f32[2,2,3]{2,1,0}"), + op::Broadcast(AllOf(op::Shape("f32[2,3]{1,0}"), + op::DynamicSlice(op::Constant(), op::Reshape(), + op::Constant()))))); +} + +TEST_F(SpmdPartitioningTest, BroadcastPropagateTiledSharding) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + constant = f32[4,3]{1,0} constant({{1,1,1},{1,4,1},{1,3,1},{1,2,1}}), + sharding={devices=[2,1]0,1} + ROOT broadcast = f32[4,4,3]{2,1,0} broadcast(constant), dimensions={1,2}, + sharding={devices=[1,2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[4,2,3]{2,1,0}"), + op::Broadcast(op::DynamicSlice( + op::Constant(), op::Reshape(), op::Constant())))); +} + +TEST_F(SpmdPartitioningTest, OutfeedSingleDevice) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + token.0 = token[] after-all() + data = f32[1024]{0} parameter(0), sharding={maximal device=0} + outfeed = token[] outfeed(data, token.0), sharding={maximal device=0} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("token[]"), + op::Conditional( + op::Compare(op::PartitionId(), op::Constant()), + op::Tuple(op::Parameter(0), op::AfterAll()), + op::Tuple(op::Parameter(0), op::AfterAll())))); + + HloInstruction* root_b0 = root->branch_computation(0)->root_instruction(); + EXPECT_THAT(root_b0, + AllOf(op::Shape("token[]"), + op::Outfeed(op::GetTupleElement(op::Parameter(), 0), + op::GetTupleElement(op::Parameter(), 1)))); + + HloInstruction* root_b1 = root->branch_computation(1)->root_instruction(); + EXPECT_THAT(root_b1, AllOf(op::Shape("token[]"), op::AfterAll())); +} + +TEST_F(SpmdPartitioningTest, ReduceWindowReplicatedInput) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + constant = f32[6,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1},{1,2},{2,2}}), + sharding={replicated} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT reduce-window = f32[3,2]{1,0} reduce-window(constant, constant.1), + window={size=3x1 stride=2x1 pad=1_0x0_0}, to_apply=sum, + sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::Shape("f32[2,2]{1,0}"), + op::ReduceWindow( + op::DynamicSlice(AllOf(op::Shape("f32[9,2]{1,0}"), + op::Pad(op::Constant(), op::Constant())), + op::Multiply(op::Reshape(), op::Constant()), + op::Constant()), + op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, ReduceWindowTiledNegativeLeftHalo) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] 
parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + constant = f32[6,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1},{1,2},{2,2}}), + sharding={devices=[2,1]0,1} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT %reduce-window = f32[3,2]{1,0} reduce-window(%constant, %constant.1), + window={size=3x1 stride=2x1 pad=0_1x0_0}, to_apply=sum, + sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + + auto sharded_input = + op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant()); + auto right_halo = AllOf(op::Shape("f32[2,2]{1,0}"), + op::CollectivePermute(op::Slice(sharded_input))); + auto pre_masking = op::DynamicSlice( + AllOf( + op::Shape("f32[6,2]{1,0}"), + op::Pad(op::Concatenate(sharded_input, right_halo), op::Constant())), + op::Reshape(), op::Constant()); + auto index_in_padded = op::Add( + op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant()))); + auto masked = + op::Select(op::Compare(index_in_padded, op::Broadcast(op::Constant())), + pre_masking, op::Broadcast(op::Constant())); + EXPECT_THAT(root, AllOf(op::Shape("f32[2,2]{1,0}"), + op::ReduceWindow(masked, op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, ReduceWindowTiledOneSideHaloBeyondNeighbor) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + param = f32[9,2] parameter(0), sharding={devices=[5,1]0,1,2,3,4} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT reduce-window = f32[5,2]{1,0} reduce-window(param, constant.1), + window={size=4x1 stride=2x1 pad=3_0x0_0}, to_apply=sum, + sharding={devices=[5,1]0,1,2,3,4} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/5)); + VLOG(1) << module->ToString(); + auto halo0 = AllOf(op::Shape("f32[1,2]"), + op::CollectivePermute(op::Slice(op::Parameter(0)))); + auto halo1 = + AllOf(op::Shape("f32[2,2]"), op::CollectivePermute(op::Parameter(0))); + auto pre_mask = + AllOf(op::Shape("f32[4,2]"), + op::Slice(AllOf(op::Shape("f32[5,2]"), + op::Concatenate(halo0, halo1, op::Parameter(0))))); + auto masked = + op::Select(op::Compare(op::Add(op::Iota(), op::Broadcast(op::Multiply())), + op::Broadcast(op::Constant())), + pre_mask, op::Broadcast(op::Constant())); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[1,2]{1,0}"), + op::ReduceWindow(masked, op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, ReduceWindowTiledOneSideUnequalHalo) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + constant = f32[9,2]{1,0} constant( + {{1,1},{1,4},{2,1},{3,1},{1,2},{2,2},{4,1},{1,2},{2,1}}), + sharding={devices=[3,1]0,1,2} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT reduce-window = f32[5,2]{1,0} reduce-window(constant, constant.1), + window={size=3x1 stride=2x1 pad=1_1x0_0}, to_apply=sum, + sharding={devices=[3,1]0,1,2} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/3)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + + auto sharded_input = + op::DynamicSlice(op::Constant(), 
op::Reshape(), op::Constant()); + auto right_halo = AllOf(op::Shape("f32[2,2]{1,0}"), + op::CollectivePermute(op::Slice(sharded_input))); + auto pre_masking = op::DynamicSlice( + AllOf( + op::Shape("f32[7,2]{1,0}"), + op::Pad(op::Concatenate(sharded_input, right_halo), op::Constant())), + op::Reshape(), op::Constant()); + auto index_in_padded = op::Add( + op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant()))); + auto masked = op::Select( + op::And(op::Compare(index_in_padded, op::Broadcast(op::Constant())), + op::Compare(index_in_padded, op::Broadcast(op::Constant()))), + pre_masking, op::Broadcast(op::Constant())); + EXPECT_THAT(root, AllOf(op::Shape("f32[2,2]{1,0}"), + op::ReduceWindow(masked, op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, ReduceWindowTiledTwoSideHalo) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}}), + sharding={devices=[2,1]0,1} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT reduce-window = f32[2,2]{1,0} reduce-window(constant, constant.1), + window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum, + sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + + auto sharded_input = + op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant()); + auto left_halo = AllOf(op::Shape("f32[1,2]{1,0}"), + op::CollectivePermute(op::Slice(sharded_input))); + auto right_halo = AllOf(op::Shape("f32[1,2]{1,0}"), + op::CollectivePermute(op::Slice(sharded_input))); + auto pre_masking = AllOf( + op::Shape("f32[5,2]{1,0}"), + op::DynamicSlice( + AllOf(op::Shape("f32[6,2]{1,0}"), + op::Pad(op::Concatenate(left_halo, sharded_input, right_halo), + op::Constant())), + op::Reshape(), op::Constant())); + auto index_in_padded = op::Add( + op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant()))); + auto masked = op::Select( + op::And(op::Compare(index_in_padded, op::Broadcast(op::Constant())), + op::Compare(index_in_padded, op::Broadcast(op::Constant()))), + pre_masking, op::Broadcast(op::Constant())); + EXPECT_THAT(root, AllOf(op::Shape("f32[1,2]{1,0}"), + op::ReduceWindow(masked, op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, ReduceWindowTiled2D) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + token0 = token[] after-all(), sharding={maximal device=0} + infeed = (f32[4,4,2,2]{3,2,1,0}, token[]) infeed(token0), + sharding={{devices=[2,2,1,1]0,1,2,3}, {maximal device=0}} + infeed.data = f32[4,4,2,2]{3,2,1,0} get-tuple-element(infeed), index=0, + sharding={devices=[2,2,1,1]0,1,2,3} + constant = f32[] constant(0), sharding={replicated} + ROOT reduce-window = f32[2,2,2,2]{3,2,1,0} reduce-window(infeed.data, constant), + window={size=5x5x1x1 stride=3x3x1x1 pad=2_2x2_2x0_0x0_0}, to_apply=sum, + sharding={devices=[2,2,1,1]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + + auto sharded_input = AllOf(op::Shape("f32[2,2,2,2]{3,2,1,0}"), + op::GetTupleElement(op::Infeed())); + auto dim0_left_halo = 
AllOf(op::Shape("f32[1,2,2,2]{3,2,1,0}"), + op::CollectivePermute(op::Slice(sharded_input))); + auto dim0_right_halo = AllOf(op::Shape("f32[1,2,2,2]{3,2,1,0}"), + op::CollectivePermute(op::Slice(sharded_input))); + auto dim0_pre_masking = op::DynamicSlice( + AllOf(op::Shape("f32[6,2,2,2]{3,2,1,0}"), + op::Pad( + op::Concatenate(dim0_left_halo, sharded_input, dim0_right_halo), + op::Constant())), + op::Reshape(), op::Constant(), op::Constant(), op::Constant()); + auto dim0_index_in_padded = op::Add( + op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant()))); + auto dim0_masked = op::Select( + op::And(op::Compare(dim0_index_in_padded, op::Broadcast(op::Constant())), + op::Compare(dim0_index_in_padded, op::Broadcast(op::Constant()))), + dim0_pre_masking, op::Broadcast(op::Constant())); + auto dim0_resharded = AllOf(op::Shape("f32[5,2,2,2]{3,2,1,0}"), dim0_masked); + auto dim1_left_halo = AllOf(op::Shape("f32[5,1,2,2]{3,2,1,0}"), + op::CollectivePermute(op::Slice(dim0_resharded))); + auto dim1_right_halo = + AllOf(op::Shape("f32[5,1,2,2]{3,2,1,0}"), + op::CollectivePermute(op::Slice(dim0_resharded))); + auto dim1_pre_masking = op::DynamicSlice( + AllOf(op::Shape("f32[5,6,2,2]{3,2,1,0}"), + op::Pad(op::Concatenate(dim1_left_halo, dim0_resharded, + dim1_right_halo), + op::Constant())), + op::Constant(), op::Reshape(), op::Constant(), op::Constant()); + auto dim1_index_in_padded = op::Add( + op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant()))); + auto dim1_masked = op::Select( + op::And(op::Compare(dim1_index_in_padded, op::Broadcast(op::Constant())), + op::Compare(dim1_index_in_padded, op::Broadcast(op::Constant()))), + dim1_pre_masking, op::Broadcast(op::Constant())); + auto dim1_resharded = AllOf(op::Shape("f32[5,5,2,2]{3,2,1,0}"), dim1_masked); + EXPECT_THAT(root, AllOf(op::Shape("f32[1,1,2,2]{3,2,1,0}"), + op::ReduceWindow(dim1_resharded, op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsReplicated) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,224,224,3] parameter(0) + %lhs.copy = f32[128,224,224,3] copy(f32[128,224,224,3] %lhs), + sharding={devices=[1,2,1,1]0,1} + %rhs = f32[7,7,3,64] parameter(1) + %rhs.copy = f32[7,7,3,64] copy(f32[7,7,3,64] %rhs), + sharding={replicated} + ROOT %conv = f32[128,112,112,64] convolution( + f32[128,224,224,3] %lhs.copy, + f32[7,7,3,64] %rhs.copy), + window={size=7x7 stride=2x2 pad=3_3x3_3}, + dim_labels=b01f_01io->b01f, + sharding={devices=[1,2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,112,224,3]")); + auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[7,7,3,64]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[128,3,224,3]")); + auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[128,2,224,3]")); + EXPECT_THAT(root, + AllOf(op::Convolution( + op::Select(op::And(), + op::Concatenate(left_halo, lhs, right_halo), + op::Broadcast()), + rhs), + op::Shape("f32[128,56,112,64]"))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsReplicatedNeedReshard) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,224,224,3] parameter(0) + 
%lhs.copy = f32[128,224,224,3] copy(f32[128,224,224,3] %lhs), + sharding={devices=[2,1,1,1]0,1} + %rhs = f32[7,7,3,64] parameter(1) + %rhs.copy = f32[7,7,3,64] copy(f32[7,7,3,64] %rhs), + sharding={replicated} + ROOT %conv = f32[128,112,112,64] convolution( + f32[128,224,224,3] %lhs.copy, + f32[7,7,3,64] %rhs.copy), + window={size=7x7 stride=2x2 pad=3_3x3_3}, + dim_labels=b01f_01io->b01f, + sharding={devices=[1,2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(), + op::Constant(), op::Constant())), + op::Shape("f32[64,224,224,3]")); + auto all_to_all = + AllOf(op::AllToAll(op::Reshape(lhs)), op::Shape("f32[64,2,112,224,3]")); + auto reshard_lhs = AllOf(op::Reshape(op::Transpose(all_to_all)), + op::Shape("f32[128,112,224,3]")); + + auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[7,7,3,64]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(reshard_lhs)), + op::Shape("f32[128,3,224,3]")); + auto right_halo = AllOf(op::CollectivePermute(op::Slice(reshard_lhs)), + op::Shape("f32[128,2,224,3]")); + EXPECT_THAT( + root, + AllOf(op::Convolution( + op::Select(op::And(), + op::Concatenate(left_halo, reshard_lhs, right_halo), + op::Broadcast()), + rhs), + op::Shape("f32[128,56,112,64]"))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsReplicatedReordered) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[224,224,3,128] parameter(0) + %lhs.copy = f32[224,224,3,128] copy(%lhs), sharding={devices=[2,1,1,1]0,1} + %rhs = f32[7,7,3,64] parameter(1) + %rhs.copy = f32[7,7,3,64] copy(%rhs), sharding={replicated} + ROOT %conv = f32[128,112,112,64] convolution(%lhs.copy, %rhs.copy), + window={size=7x7 stride=2x2 pad=3_3x3_3}, + dim_labels=01fb_01io->b01f, + sharding={devices=[1,2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(), + op::Constant(), op::Constant())), + op::Shape("f32[112,224,3,128]")); + auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[7,7,3,64]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[3,224,3,128]")); + auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[2,224,3,128]")); + EXPECT_THAT(root, + AllOf(op::Convolution( + op::Select(op::And(), + op::Concatenate(left_halo, lhs, right_halo), + op::Broadcast()), + rhs), + op::Shape("f32[128,56,112,64]"))); +} + +// (stride * per_shard_window_count) % dilation == 0 +TEST_F(SpmdPartitioningTest, + ConvolutionBaseDilationSameStartPatternLhsTiledRhsReplicated) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,7,7,512] parameter(0) + %lhs.copy = f32[128,7,7,512] copy(%lhs), + sharding={devices=[1,2,1,1]0,1} + %rhs = f32[3,3,512,512] parameter(1) + %rhs.copy = f32[3,3,512,512] copy(%rhs), + sharding={replicated} + ROOT %conv = f32[128,4,4,512] convolution(%lhs.copy, %rhs.copy), + window={size=3x3 stride=4x4 pad=1_1x1_1 lhs_dilate=2x2 rhs_reversal=1x1}, + dim_labels=b01f_01io->b01f, + sharding={devices=[1,2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + 
PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + // There is no halo exchange, and because the last element in the shard is not + // needed (stride == 4), the LHS will be just a slice. + auto sliced_lhs = + AllOf(op::Slice(op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Constant()))), + op::Shape("f32[128,3,7,512]")); + auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[3,3,512,512]")); + EXPECT_THAT(root, AllOf(op::Convolution(sliced_lhs, rhs), + op::Shape("f32[128,2,4,512]"))); + EXPECT_EQ(root->window().dimensions(0).padding_low(), 1); + EXPECT_EQ(root->window().dimensions(0).padding_high(), 1); +} + +// (stride * per_shard_window_count) % dilation != 0 but stride == 1 +TEST_F(SpmdPartitioningTest, + ConvolutionBaseDilationStride1LhsTiledRhsReplicated) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,7,7,512] parameter(0) + %lhs.copy = f32[128,7,7,512] copy(%lhs), + sharding={devices=[1,2,1,1]0,1} + %rhs = f32[3,3,512,512] parameter(1) + %rhs.copy = f32[3,3,512,512] copy(%rhs), + sharding={replicated} + ROOT %conv = f32[128,14,14,512] convolution(%lhs.copy, %rhs.copy), + window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, + dim_labels=b01f_01io->b01f, + sharding={devices=[1,2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Constant())), + op::Shape("f32[128,4,7,512]")); + auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[3,3,512,512]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[128,1,7,512]")); + auto start_window = op::Multiply(op::Reshape(), op::Constant()); + auto start_input_element = op::Divide(start_window, op::Constant()); + auto dynamic_offset_for_padded_concat = op::Subtract( + op::Constant(), op::Subtract(op::Multiply(op::Reshape(), op::Constant()), + start_input_element)); + auto pre_masking = + AllOf(op::Shape("f32[128,5,7,512]"), + op::DynamicSlice( + AllOf(op::Shape("f32[128,6,7,512]"), + op::Pad(op::Concatenate(left_halo, lhs), op::Constant())), + op::Constant(), dynamic_offset_for_padded_concat, + op::Constant(), op::Constant())); + auto masked = op::Select( + op::Compare(op::Add(op::Iota(), op::Broadcast(start_input_element)), + op::Broadcast(op::Constant())), + pre_masking, op::Broadcast(op::Constant())); + auto dynamic_offset_on_output = op::Subtract( + start_window, op::Multiply(start_input_element, op::Constant())); + EXPECT_THAT(root, + AllOf(op::DynamicSlice(AllOf(op::Convolution(masked, rhs), + op::Shape("f32[128,8,14,512]")), + op::Constant(), dynamic_offset_on_output, + op::Constant(), op::Constant()), + op::Shape("f32[128,7,14,512]"))); + EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_low(), 1); + EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_high(), 0); +} + +TEST_F(SpmdPartitioningTest, SelectAndScatterNoOverlap) { + const char* const hlo_string = R"( +HloModule module + +ge { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT compare = pred[] compare(a, b), direction=GE +} + +sum { + c = f32[] parameter(0) + d = f32[] parameter(1) + ROOT 
add = f32[] add(c, d) +} + +ENTRY entry { + %param = f32[11,4]{1,0} parameter(0) + %param.copy = f32[11,4] copy(%param), + sharding={devices=[4,1]0,1,2,3} + constant = f32[4,2]{1,0} constant({{1,2},{3,4},{1,0},{2,8}}), + sharding={devices=[4,1]0,1,2,3} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT select-and-scatter = f32[11,4]{1,0} select-and-scatter(param.copy, + constant, constant.1), window={size=3x2 stride=3x2 pad=0_1x0_0}, + select=ge, scatter=sum, sharding={devices=[4,1]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto source = + AllOf(op::Shape("f32[1,2]{1,0}"), + op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant())); + auto masked_data = AllOf( + op::Shape("f32[3,4]{1,0}"), + op::Select( + op::Compare(op::Add(op::Iota(), op::Broadcast(op::Multiply( + op::Reshape(), op::Constant()))), + op::Broadcast(op::Constant())), + op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()), + op::Reshape(), op::Constant())), + op::Broadcast(op::Constant()))); + + EXPECT_THAT(root, + AllOf(op::SelectAndScatter(masked_data, source, op::Constant()), + op::Shape("f32[3,4]{1,0}"))); + EXPECT_EQ(root->window().dimensions(0).padding_low(), 0); + EXPECT_EQ(root->window().dimensions(0).padding_high(), 0); +} + +TEST_F(SpmdPartitioningTest, SelectAndScatterNoOverlapReshard) { + const char* const hlo_string = R"( +HloModule module + +ge { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT compare = pred[] compare(a, b), direction=GE +} + +sum { + c = f32[] parameter(0) + d = f32[] parameter(1) + ROOT add = f32[] add(c, d) +} + +ENTRY entry { + %param = f32[11,4]{1,0} parameter(0) + %param.copy = f32[11,4] copy(%param), + sharding={devices=[1,4]0,1,2,3} + constant = f32[4,2]{1,0} constant({{1,2},{3,4},{1,0},{2,8}}), + sharding={devices=[4,1]0,1,2,3} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT select-and-scatter = f32[11,4]{1,0} select-and-scatter(param.copy, + constant, constant.1), window={size=3x2 stride=3x2 pad=0_1x0_0}, + select=ge, scatter=sum, sharding={devices=[4,1]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto source = + AllOf(op::Shape("f32[1,2]{1,0}"), + op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant())); + auto operand = AllOf(op::Copy(op::DynamicSlice( + op::Parameter(0), op::Constant(), op::Reshape())), + op::Shape("f32[11,1]")); + auto reshard_operand = op::Reshape(op::Transpose( + op::AllToAll(op::Reshape(op::Pad(operand, op::Constant()))))); + auto masked_data = AllOf( + op::Shape("f32[3,4]{1,0}"), + op::Select( + op::Compare(op::Add(op::Iota(), op::Broadcast(op::Multiply( + op::Reshape(), op::Constant()))), + op::Broadcast(op::Constant())), + reshard_operand, op::Broadcast(op::Constant()))); + + EXPECT_THAT(root, + AllOf(op::SelectAndScatter(masked_data, source, op::Constant()), + op::Shape("f32[3,4]{1,0}"))); + EXPECT_EQ(root->window().dimensions(0).padding_low(), 0); + EXPECT_EQ(root->window().dimensions(0).padding_high(), 0); +} + +TEST_F(SpmdPartitioningTest, SelectAndScatterWithOverlap) { + const char* const hlo_string = R"( +HloModule module + +ge { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT compare = pred[] compare(a, b), direction=GE +} + +sum { + c = f32[] parameter(0) 
+ d = f32[] parameter(1) + ROOT add = f32[] add(c, d) +} + +ENTRY entry { + %param = f32[11,4]{1,0} parameter(0) + %param.copy = f32[11,4] copy(%param), + sharding={devices=[4,1]0,1,2,3} + constant = f32[6,2]{1,0} constant({{1,2},{3,4},{1,0},{2,8},{6,6},{1,9}}), + sharding={devices=[4,1]0,1,2,3} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT select-and-scatter = f32[11,4]{1,0} select-and-scatter(param.copy, + constant, constant.1), window={size=3x2 stride=2x2 pad=1_1x0_0}, + select=ge, scatter=sum, sharding={devices=[4,1]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + + auto source_shard = + AllOf(op::Shape("f32[2,2]{1,0}"), + op::DynamicSlice(op::Pad(), op::Reshape(), op::Constant())); + // Max halo size is the same as the shard size, so slice is not needed. + auto source_left_halo = op::CollectivePermute(source_shard); + auto required_source_shard_start = + op::Divide(op::Multiply(op::Reshape(), op::Constant()), op::Constant()); + auto source_with_halo = op::DynamicSlice( + AllOf(op::Shape("f32[5,2]{1,0}"), + op::Pad(op::Concatenate(source_left_halo, source_shard), + op::Constant())), + op::Subtract(op::Constant(), + op::Subtract(op::Multiply(op::Reshape(), op::Constant()), + required_source_shard_start)), + op::Constant()); + auto masked_source_with_halo = AllOf( + AllOf(op::Shape("f32[3,2]{1,0}")), + op::Select( + op::Compare( + op::Add(op::Iota(), op::Broadcast(required_source_shard_start)), + op::Broadcast(op::Constant())), + source_with_halo, op::Broadcast(op::Constant()))); + + auto data_shard = + AllOf(op::Shape("f32[3,4]{1,0}"), + op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()), + op::Reshape(), op::Constant()))); + auto data_left_halo = AllOf(op::Shape("f32[2,4]{1,0}"), + op::CollectivePermute(op::Slice(data_shard))); + auto data_right_halo = AllOf(op::Shape("f32[2,4]{1,0}"), + op::CollectivePermute(op::Slice(data_shard))); + auto required_data_start_on_padded = + op::Multiply(required_source_shard_start, op::Constant()); + auto left_halo_size = op::Subtract( + op::Add(op::Multiply(op::Reshape(), op::Constant()), op::Constant()), + required_data_start_on_padded); + auto data_with_halo = + AllOf(op::Shape("f32[7,4]{1,0}"), + op::DynamicSlice( + AllOf(op::Shape("f32[8,4]{1,0}"), + op::Pad(op::Concatenate(data_left_halo, data_shard, + data_right_halo), + op::Constant())), + op::Subtract(op::Constant(), left_halo_size), op::Constant())); + auto index_on_padded = + op::Add(op::Iota(), op::Broadcast(required_data_start_on_padded)); + auto masked_data_with_halo = op::Select( + op::And(op::Compare(index_on_padded, op::Broadcast(op::Constant())), + op::Compare(index_on_padded, op::Broadcast(op::Constant()))), + data_with_halo, op::Broadcast(op::Constant())); + + EXPECT_THAT( + root, AllOf(op::DynamicSlice(op::SelectAndScatter(masked_data_with_halo, + masked_source_with_halo, + op::Constant()), + left_halo_size, op::Constant()), + op::Shape("f32[3,4]{1,0}"))); + EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_low(), 0); + EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_high(), 0); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiled) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,56,56,64] parameter(0) + %lhs.copy = f32[128,56,56,64] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,56,56,256] parameter(1) + 
%rhs.copy = f32[128,56,56,256] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[1,1,64,256] convolution(%lhs.copy, %rhs.copy), + window={size=56x56}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,28,56,64]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,28,56,256]")); + + EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(lhs, rhs)), + op::Shape("f32[1,1,64,256]"))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWindowReversal) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[5,128,64] parameter(0), sharding={devices=[2,1,1]0,1} + %rhs = f32[5,128,256] parameter(1), sharding={devices=[2,1,1]1,0} + ROOT %conv = f32[1,64,256] convolution(%lhs, %rhs), + window={size=5 rhs_reversal=1}, dim_labels=0fb_0io->0bf, + sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto lhs_masked = + AllOf(op::Shape("f32[3,128,64]"), op::Select(_, op::Parameter(0), _)); + auto rhs_left_padded = op::Slice(op::Concatenate( + op::CollectivePermute(op::Slice(op::Parameter(1))), op::Parameter(1))); + auto rhs_masked = + AllOf(op::Shape("f32[3,128,256]"), op::Select(_, rhs_left_padded, _)); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + AllOf(op::AllReduce(op::Convolution(lhs_masked, rhs_masked)), + op::Shape("f32[1,64,256]"))); +} + +TEST_F(SpmdPartitioningTest, DotLhsTiledRhsTiledWithReshard) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,56,56,64] parameter(0) + %lhs.copy = f32[128,56,56,64] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,56,56,256] parameter(1) + %rhs.copy = f32[128,56,56,256] copy(%rhs), sharding={devices=[2,1,1,1]0,1} + ROOT %conv = f32[1,1,64,256] convolution(%lhs.copy, %rhs.copy), + window={size=56x56}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,28,56,64]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(), + op::Constant(), op::Constant())), + op::Shape("f32[64,56,56,256]")); + auto all_to_all = + AllOf(op::AllToAll(op::Reshape(lhs)), op::Shape("f32[2,64,28,56,64]")); + auto reshard = AllOf(op::Reshape(op::Transpose(all_to_all))); + + EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(reshard, rhs)), + op::Shape("f32[1,1,64,256]"))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWithReshard) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,56,56,512] parameter(0) + %lhs.copy = f32[128,56,56,512] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,28,28,64] parameter(1) + %rhs.copy = 
f32[128,28,28,64] copy(%rhs), sharding={devices=[2,1,1,1]0,1} + ROOT %conv = f32[1,1,512,64] convolution(%lhs.copy, %rhs.copy), + window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2}, + dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,28,56,512]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(), + op::Constant(), op::Constant())), + op::Shape("f32[64,28,28,64]")); + auto all_to_all = + AllOf(op::AllToAll(op::Reshape(rhs)), op::Shape("f32[64,2,14,28,64]")); + auto reshard = op::Reshape(op::Transpose(all_to_all)); + + EXPECT_THAT(root, + AllOf(op::AllReduce(op::Convolution(op::Slice(lhs), reshard)), + op::Shape("f32[1,1,512,64]"))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWithPadding) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,28,28,128] parameter(0) + %lhs.copy = f32[32,28,28,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[32,28,28,64] parameter(1) + %rhs.copy = f32[32,28,28,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[3,3,128,64] convolution(%lhs.copy, %rhs.copy), + window={size=28x28 pad=1_1x1_1}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + PartitionComputation(hlo_string, /*num_devices=*/2, + /*conv_halo_exchange_always_on_lhs=*/false)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[32,14,28,128]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[32,14,28,64]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(rhs)), + op::Shape("f32[32,1,28,64]")); + auto right_halo = AllOf(op::CollectivePermute(op::Slice(rhs)), + op::Shape("f32[32,1,28,64]")); + EXPECT_THAT(root, + AllOf(op::AllReduce(op::Convolution( + lhs, AllOf(op::Concatenate(left_halo, rhs, right_halo), + op::Shape("f32[32,16,28,64]")))), + op::Shape("f32[3,3,128,64]"))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWindowDilate) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,224,224,3] parameter(0) + %lhs.copy = f32[128,224,224,3] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,112,112,64] parameter(1) + %rhs.copy = f32[128,112,112,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[7,7,3,64] convolution(%lhs.copy, %rhs.copy), + window={size=112x112 pad=3_2x3_2 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + PartitionComputation(hlo_string, /*num_devices=*/2, + /*conv_halo_exchange_always_on_lhs=*/false)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,112,224,3]")); + auto rhs = AllOf( + 
op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,56,112,64]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(rhs)), + op::Shape("f32[128,2,112,64]")); + auto right_halo = AllOf(op::CollectivePermute(op::Slice(rhs)), + op::Shape("f32[128,2,112,64]")); + EXPECT_THAT(root, + AllOf(op::AllReduce(op::Convolution( + lhs, AllOf(op::Concatenate(left_halo, rhs, right_halo), + op::Shape("f32[128,60,112,64]")))), + op::Shape("f32[7,7,3,64]"))); +} + +TEST_F(SpmdPartitioningTest, + ConvolutionLhsTiledRhsTiledWindowDilateNegativeRhsPadding) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,56,56,256] parameter(0) + %lhs.copy = f32[128,56,56,256] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,28,28,512] parameter(1) + %rhs.copy = f32[128,28,28,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[1,1,256,512] convolution(%lhs.copy, %rhs.copy), + window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + PartitionComputation(hlo_string, /*num_devices=*/2, + /*conv_halo_exchange_always_on_lhs=*/false)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,28,56,256]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,14,28,512]")); + + EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(lhs, rhs)), + op::Shape("f32[1,1,256,512]"))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWindowDilateUneven) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,14,14,512] parameter(0) + %lhs.copy = f32[128,14,14,512] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,7,7,512] parameter(1) + %rhs.copy = f32[128,7,7,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[3,3,512,512] convolution(%lhs.copy, %rhs.copy), + window={size=7x7 pad=1_0x1_0 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + PartitionComputation(hlo_string, /*num_devices=*/2, + /*conv_halo_exchange_always_on_lhs=*/false)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,7,14,512]")); + auto rhs = AllOf( + op::Select(op::Compare(), + op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Constant())), + op::Broadcast()), + op::Shape("f32[128,4,7,512]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(rhs)), + op::Shape("f32[128,1,7,512]")); + EXPECT_THAT(root, + AllOf(op::AllReduce(op::Convolution( + AllOf(op::DynamicSlice(op::Pad(lhs, op::Constant()), + op::Constant(), op::Subtract(), + op::Constant(), op::Constant()), + op::Shape("f32[128,10,14,512]")), + AllOf(op::Concatenate(left_halo, rhs), + op::Shape("f32[128,5,7,512]")))), + op::Shape("f32[3,3,512,512]"))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWithPadding_HaloOnLhs) { + const char* const 
hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,28,28,128] parameter(0) + %lhs.copy = f32[32,28,28,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[32,28,28,64] parameter(1) + %rhs.copy = f32[32,28,28,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[3,3,128,64] convolution(%lhs.copy, %rhs.copy), + window={size=28x28 pad=1_1x1_1}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[32,14,28,128]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[32,14,28,64]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[32,1,28,128]")); + auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[32,1,28,128]")); + EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution( + AllOf(op::Concatenate(left_halo, lhs, right_halo), + op::Shape("f32[32,16,28,128]")), + rhs)), + op::Shape("f32[3,3,128,64]"))); +} + +TEST_F(SpmdPartitioningTest, + ConvolutionLhsTiledRhsTiledWindowDilate_HaloOnLhs) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,224,224,3] parameter(0) + %lhs.copy = f32[128,224,224,3] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,112,112,64] parameter(1) + %rhs.copy = f32[128,112,112,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[7,7,3,64] convolution(%lhs.copy, %rhs.copy), + window={size=112x112 pad=3_2x3_2 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,112,224,3]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,56,112,64]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[128,3,224,3]")); + auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[128,2,224,3]")); + EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution( + AllOf(op::Concatenate(left_halo, lhs, right_halo), + op::Shape("f32[128,117,224,3]")), + rhs)), + op::Shape("f32[7,7,3,64]"))); +} + +TEST_F(SpmdPartitioningTest, + ConvolutionLhsTiledRhsTiledWindowDilateNegativeRhsPadding_HaloOnLhs) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,56,56,256] parameter(0) + %lhs.copy = f32[128,56,56,256] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,28,28,512] parameter(1) + %rhs.copy = f32[128,28,28,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[1,1,256,512] convolution(%lhs.copy, %rhs.copy), + window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << 
module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,28,56,256]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,14,28,512]")); + + EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(op::Slice(lhs), rhs)), + op::Shape("f32[1,1,256,512]"))); +} + +TEST_F(SpmdPartitioningTest, + ConvolutionLhsTiledRhsTiledWindowDilateUneven_HaloOnLhs) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,14,14,512] parameter(0) + %lhs.copy = f32[128,14,14,512] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,7,7,512] parameter(1) + %rhs.copy = f32[128,7,7,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[3,3,512,512] convolution(%lhs.copy, %rhs.copy), + window={size=7x7 pad=1_0x1_0 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,7,14,512]")); + auto rhs = AllOf( + op::Select(op::Compare(), + op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Constant())), + op::Broadcast()), + op::Shape("f32[128,4,7,512]")); + + auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[128,1,14,512]")); + EXPECT_THAT( + root, AllOf(op::AllReduce(op::Convolution( + AllOf(op::DynamicSlice( + AllOf(op::Pad(op::Concatenate(lhs, right_halo), + op::Constant()), + op::Shape("f32[128,10,14,512]")), + op::Constant(), op::Reshape(), op::Constant(), + op::Constant()), + op::Shape("f32[128,9,14,512]")), + rhs)), + op::Shape("f32[3,3,512,512]"))); +} + +TEST_F(SpmdPartitioningTest, ConcatenateAlongNonPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[14,257] parameter(0) + %param0.copy = f32[14,257] copy(%param0), sharding={devices=[2,1]0,1} + %param1 = f32[14,116] parameter(1) + %param1.copy = f32[14,116] copy(%param1), sharding={devices=[2,1]0,1} + ROOT %concatenate = f32[14,373] concatenate(%param0.copy, %param1.copy), + dimensions={1}, sharding={devices=[2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), + op::Constant())), + op::Shape("f32[7,257]")); + auto param1 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), + op::Constant())), + op::Shape("f32[7,116]")); + EXPECT_THAT(root, + AllOf(op::Concatenate(param0, param1), op::Shape("f32[7,373]"))); +} + +TEST_F(SpmdPartitioningTest, ConcatenateAlongPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[14,257] parameter(0) + %param0.copy = f32[14,257] copy(%param0), sharding={devices=[1,2]0,1} + %param1 = f32[14,116] parameter(1) + %param1.copy = f32[14,116] copy(%param1), sharding={devices=[1,2]0,1} + 
ROOT %concatenate = f32[14,373] concatenate(%param0.copy, %param1.copy), + dimensions={1}, sharding={devices=[1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = + AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()), + op::Constant(), op::Reshape())), + op::Shape("f32[14,129]")); + auto param1 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), + op::Reshape())), + op::Shape("f32[14,58]")); + EXPECT_THAT(root, AllOf(op::DynamicSlice( + AllOf(op::AllReduce(op::DynamicUpdateSlice( + op::DynamicUpdateSlice( + op::Broadcast(), param0, + op::Constant(), op::Multiply()), + param1, op::Constant(), op::Add())), + op::Shape("f32[14,374]")), + op::Constant(), op::Multiply()), + op::Shape("f32[14,187]"))); +} + +TEST_F(SpmdPartitioningTest, PadAlongNonPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[128,14,257] parameter(0), sharding={devices=[1,1,2]0,1} + %const = f32[] constant(0) + ROOT %pad = f32[128,17,257] pad(%param0, %const), padding=0_0x1_2x0_0, + sharding={devices=[1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf(op::Parameter(), op::Shape("f32[128,14,129]")); + EXPECT_THAT(root, AllOf(op::Pad(param0, op::Constant()), + op::Shape("f32[128,17,129]"))); +} + +TEST_F(SpmdPartitioningTest, PadAlongPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[14,257] parameter(0), sharding={devices=[1,2]0,1} + %const = f32[] constant(0) + ROOT %pad = f32[14,259] pad(%param0, %const), padding=0_0x0_2, + sharding={devices=[1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf(op::Parameter(), op::Shape("f32[14,129]")); + auto after_halo_exchange = + AllOf(op::Shape("f32[14,130]"), + op::Concatenate(param0, op::CollectivePermute(op::Slice(param0)))); + auto pad = AllOf(op::Shape("f32[14,131]"), + op::Pad(after_halo_exchange, op::Constant())); + EXPECT_THAT(root, op::DynamicSlice(pad, op::Constant(), _)); +} + +TEST_F(SpmdPartitioningTest, PadAlongPartitionedDimensionWithInteriorPadding) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[7] parameter(0), sharding={devices=[2]0,1} + %param1 = f32[] parameter(1), sharding={replicated} + ROOT %pad = f32[22] pad(%param0, %param1), padding=2_1_2, + sharding={devices=[2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + + auto param0 = AllOf(op::Parameter(), op::Shape("f32[4]")); + auto after_halo_exchange = + AllOf(op::Shape("f32[4]"), + op::DynamicSlice( + AllOf(op::Shape("f32[5]"), + op::Concatenate(op::CollectivePermute(op::Slice(param0)), + param0)), + _)); + auto pad = op::Pad(after_halo_exchange, op::Parameter(1)); + EXPECT_THAT(root, op::DynamicSlice(pad, _)); +} + +TEST_F(SpmdPartitioningTest, SliceAlongNonPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module 
+ +ENTRY entry { + %param0 = f32[128,14,257] parameter(0) + %param0.copy = f32[128,14,257] copy(%param0), sharding={devices=[1,1,2]0,1} + ROOT %slice = f32[128,11,257] slice(%param0.copy), + slice={[0:128:1], [2:13:1], [0:257:1]}, sharding={devices=[1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf( + op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()), + op::Constant(), op::Constant(), op::Reshape())), + op::Shape("f32[128,14,129]")); + EXPECT_THAT(root, AllOf(op::Slice(param0), op::Shape("f32[128,11,129]"))); +} + +TEST_F(SpmdPartitioningTest, SliceAlongPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[128,14,257] parameter(0) + %param0.copy = f32[128,14,257] copy(%param0), sharding={devices=[1,1,2]0,1} + ROOT %slice = f32[63,14,251] slice(%param0.copy), + slice={[2:128:2], [0:14:1], [5:256:1]}, sharding={devices=[1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf( + op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()), + op::Constant(), op::Constant(), op::Reshape())), + op::Shape("f32[128,14,129]")); + EXPECT_THAT( + root, + AllOf(op::Slice(AllOf( + op::DynamicSlice( + AllOf(op::Concatenate( + param0, + AllOf(op::CollectivePermute(op::Slice(param0)), + op::Shape("f32[128,14,2]"))), + op::Shape("f32[128,14,131]")), + op::Constant(), op::Constant(), op::Add()), + op::Shape("f32[128,14,126]"))), + op::Shape("f32[63,14,126]"))); +} + +TEST_F(SpmdPartitioningTest, SortAlongNonPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ge { + p.0.lhs.1247 = f32[]{:T(256)} parameter(0), sharding={replicated} + bitcast-convert = s32[]{:T(256)} bitcast-convert(p.0.lhs.1247), sharding={replicated} + constant = s32[]{:T(256)} constant(0), sharding={replicated} + compare = pred[]{:T(256)E(32)} compare(bitcast-convert, constant), direction=LT, sharding={replicated} + constant.1 = u32[]{:T(256)} constant(2147483647), sharding={replicated} + bitcast-convert.1 = u32[]{:T(256)} bitcast-convert(p.0.lhs.1247), sharding={replicated} + subtract = u32[]{:T(256)} subtract(constant.1, bitcast-convert.1), sharding={replicated} + bitcast-convert.2 = s32[]{:T(256)} bitcast-convert(subtract), sharding={replicated} + select = s32[]{:T(256)} select(compare, bitcast-convert.2, bitcast-convert), sharding={replicated} + p.0.rhs.1248 = f32[]{:T(256)} parameter(1), sharding={replicated} + bitcast-convert.3 = s32[]{:T(256)} bitcast-convert(p.0.rhs.1248), sharding={replicated} + compare.1 = pred[]{:T(256)E(32)} compare(bitcast-convert.3, constant), direction=LT, sharding={replicated} + bitcast-convert.4 = u32[]{:T(256)} bitcast-convert(p.0.rhs.1248), sharding={replicated} + subtract.1 = u32[]{:T(256)} subtract(constant.1, bitcast-convert.4), sharding={replicated} + bitcast-convert.5 = s32[]{:T(256)} bitcast-convert(subtract.1), sharding={replicated} + select.1 = s32[]{:T(256)} select(compare.1, bitcast-convert.5, bitcast-convert.3), sharding={replicated} + compare.2 = pred[]{:T(256)E(32)} compare(select, select.1), direction=GT, sharding={replicated} + compare.258 = pred[]{:T(256)E(32)} compare(select.1, select), direction=GT, sharding={replicated} + 
compare.259 = pred[]{:T(256)E(32)} compare(compare.2, compare.258), direction=EQ, sharding={replicated} + p.1.lhs.1249 = s32[]{:T(256)} parameter(2), sharding={replicated} + p.1.rhs.1250 = s32[]{:T(256)} parameter(3), sharding={replicated} + compare.260 = pred[]{:T(256)E(32)} compare(p.1.lhs.1249, p.1.rhs.1250), direction=LT, sharding={replicated} + ROOT select.86 = pred[]{:T(256)E(32)} select(compare.259, compare.260, compare.2), sharding={replicated} +} + +ENTRY entry { + %param0 = f32[128,14,257] parameter(0) + %param0.copy = f32[128,14,257] copy(%param0), sharding={devices=[1,2,1]0,1} + %param1 = s32[128,14,257] parameter(1) + %param1.copy = s32[128,14,257] copy(%param1), sharding={devices=[1,2,1]0,1} + ROOT %sort.6 = (f32[128,14,257]{2,1,0:T(8,128)}, s32[128,14,257]{2,1,0:T(8,128)}) + sort(%param0.copy, %param1.copy), dimensions={2}, is_stable=true, + to_apply=%ge, sharding={{devices=[1,2,1]0,1},{devices=[1,2,1]0,1}} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = + AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), + op::Reshape(), op::Constant())), + op::Shape("f32[128,7,257]")); + auto param1 = + AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(), + op::Reshape(), op::Constant())), + op::Shape("s32[128,7,257]")); + EXPECT_THAT(root, AllOf(op::Sort(param0, param1), + op::Shape("(f32[128,7,257], s32[128,7,257])"))); +} + +TEST_F(SpmdPartitioningTest, PartitionCustomCall) { + const char* const hlo_string = R"( +HloModule cluster_2013453984438090939__.47 + +ENTRY %cluster_2013453984438090939__.47 + (arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) { + %arg_tuple.1 = bf16[2,209664] parameter(0) + %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1} + %custom-call = (bf16[2,2000]{1,0}, s32[2,2000]{1,0}) + custom-call(bf16[2,209664]{1,0} %copy.arg_tuple.1), custom_call_target="TopK" + %get-tuple-element = bf16[2,2000]{1,0} + get-tuple-element((bf16[2,2000]{1,0}, s32[2,2000]{1,0}) %custom-call), + index=0, sharding={replicated} + %get-tuple-element.1 = s32[2,2000]{1,0} get-tuple-element((bf16[2,2000]{1,0}, + s32[2,2000]{1,0}) %custom-call), index=1, sharding={replicated} + ROOT %tuple.46 = (bf16[2,2000]{1,0}, s32[2,2000]{1,0}) + tuple(bf16[2,2000]{1,0} %get-tuple-element, s32[2,2000]{1,0} + %get-tuple-element.1), sharding={{replicated}, {replicated}}, + metadata={op_name="XLA_Retvals"} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto custom_call = FindInstruction(module.get(), "custom-call.1"); + EXPECT_EQ(custom_call->operand(0)->shape().dimensions(1), 104832); + auto sort = FindInstruction(module.get(), "sort"); + EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 4000); + EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 4000); +} + +TEST_F(SpmdPartitioningTest, PartitionSortInTopK) { + const char* const hlo_string = R"( +HloModule module + +%compare-greater-than.8 (p.0.lhs.9: bf16[], p.0.rhs.10: bf16[], p.1.lhs.11: + s32[], p.1.rhs.12: s32[]) -> pred[] { + %p.1.lhs.11 = s32[] parameter(2) + %p.1.rhs.12 = s32[] parameter(3) + %p.0.lhs.9 = bf16[] parameter(0) + %convert.13 = f32[] convert(bf16[] %p.0.lhs.9) + %bitcast-convert.16 = s32[] bitcast-convert(f32[] %convert.13) + %constant.20 = s32[] constant(0) + %compare.21 = pred[] compare(s32[] %bitcast-convert.16, s32[] 
%constant.20), + direction=LT + %constant.15 = u32[] constant(2147483647) + %bitcast-convert.17 = u32[] bitcast-convert(f32[] %convert.13) + %subtract.18 = u32[] subtract(u32[] %constant.15, u32[] %bitcast-convert.17) + %bitcast-convert.19 = s32[] bitcast-convert(u32[] %subtract.18) + %select.22 = s32[] select(pred[] %compare.21, s32[] %bitcast-convert.19, s32[] + %bitcast-convert.16) + %p.0.rhs.10 = bf16[] parameter(1) + %convert.14 = f32[] convert(bf16[] %p.0.rhs.10) + %bitcast-convert.24 = s32[] bitcast-convert(f32[] %convert.14) + %constant.28 = s32[] constant(0) + %compare.29 = pred[] compare(s32[] %bitcast-convert.24, s32[] %constant.28), + direction=LT + %constant.23 = u32[] constant(2147483647) + %bitcast-convert.25 = u32[] bitcast-convert(f32[] %convert.14) + %subtract.26 = u32[] subtract(u32[] %constant.23, u32[] %bitcast-convert.25) + %bitcast-convert.27 = s32[] bitcast-convert(u32[] %subtract.26) + %select.30 = s32[] select(pred[] %compare.29, s32[] %bitcast-convert.27, s32[] + %bitcast-convert.24) + ROOT %compare.31 = pred[] compare(s32[] %select.22, s32[] %select.30), + direction=GT +} + +ENTRY entry + (arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) { + %arg_tuple.1 = bf16[2,209664] parameter(0) + %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1} + %iota.7 = s32[2,209664] iota(), iota_dimension=1, + metadata={op_type="TopKV2" op_name="TopKV2"} + %sort.32 = (bf16[2,209664], s32[2,209664]) + sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7), + dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8, + metadata={op_type="TopKV2" op_name="TopKV2"} + %get-tuple-element.33 = bf16[2,209664] + get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32), + index=0, metadata={op_type="TopKV2" op_name="TopKV2"} + %slice.34 = bf16[2,2000] slice(bf16[2,209664] + %get-tuple-element.33), slice={[0:2], [0:2000]}, + metadata={op_type="TopKV2" op_name="TopKV2"} + %get-tuple-element.35 = s32[2,209664] + get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32), + index=1, metadata={op_type="TopKV2" op_name="TopKV2"} + %slice.36 = s32[2,2000] slice(s32[2,209664] + %get-tuple-element.35), slice={[0:2], [0:2000]}, + metadata={op_type="TopKV2" op_name="TopKV2"} + ROOT %tuple.46 = (bf16[2,2000], s32[2,2000]) + tuple(bf16[2,2000] %slice.34, s32[2,2000] + %slice.36), sharding={{replicated}, {replicated}}, + metadata={op_name="XLA_Retvals"} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto sort = FindInstruction(module.get(), "sort"); + EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 104832); + EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 104832); + auto final_sort = FindInstruction(module.get(), "sort.1"); + EXPECT_EQ(final_sort->operand(0)->shape().dimensions(1), 4000); + EXPECT_EQ(final_sort->operand(1)->shape().dimensions(1), 4000); +} + +TEST_F(SpmdPartitioningTest, PartitionSortInTopKWhenComparisonWithSelect) { + const char* const hlo_string = R"( +HloModule module + +%compare-greater-than.8 (p.0.lhs.2566: bf16[], + p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] { + %p.0.lhs.2566 = bf16[] parameter(0) + %convert.164 = f32[] convert(bf16[] %p.0.lhs.2566) + %bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164) + %constant.285 = s32[] constant(0) + %compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285), + direction=LT + %constant.286 = u32[] constant(2147483647) + %bitcast-convert.49 = 
u32[] bitcast-convert(f32[] %convert.164) + %subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49) + %bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84) + %select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50, + s32[] %bitcast-convert.48) + %p.0.rhs.2567 = bf16[] parameter(1) + %convert.165 = f32[] convert(bf16[] %p.0.rhs.2567) + %bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165) + %compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285), + direction=LT + %bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165) + %subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52) + %bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85) + %select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53, + s32[] %bitcast-convert.51) + %compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT + %compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT + %compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645), + direction=EQ + %p.1.lhs.2586 = s32[] parameter(2) + %p.1.rhs.2587 = s32[] parameter(3) + %compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587), + direction=LT + ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647, + pred[] %compare.86) +} + +ENTRY entry + (arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) { + %arg_tuple.1 = bf16[2,209664] parameter(0) + %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1} + %iota.7 = s32[2,209664] iota(), iota_dimension=1, + metadata={op_type="TopKV2" op_name="TopKV2"} + %sort.32 = (bf16[2,209664], s32[2,209664]) + sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7), + dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8, + metadata={op_type="TopKV2" op_name="TopKV2"} + %get-tuple-element.33 = bf16[2,209664] + get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32), + index=0, metadata={op_type="TopKV2" op_name="TopKV2"} + %slice.34 = bf16[2,2000] slice(bf16[2,209664] + %get-tuple-element.33), slice={[0:2], [0:2000]}, + metadata={op_type="TopKV2" op_name="TopKV2"} + %get-tuple-element.35 = s32[2,209664] + get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32), + index=1, metadata={op_type="TopKV2" op_name="TopKV2"} + %slice.36 = s32[2,2000] slice(s32[2,209664] + %get-tuple-element.35), slice={[0:2], [0:2000]}, + metadata={op_type="TopKV2" op_name="TopKV2"} + ROOT %tuple.46 = (bf16[2,2000], s32[2,2000]) + tuple(bf16[2,2000] %slice.34, s32[2,2000] + %slice.36), sharding={{replicated}, {replicated}}, + metadata={op_name="XLA_Retvals"} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto sort = FindInstruction(module.get(), "sort"); + EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 104832); + EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 104832); + auto final_sort = FindInstruction(module.get(), "sort.1"); + EXPECT_EQ(final_sort->operand(0)->shape().dimensions(1), 4000); + EXPECT_EQ(final_sort->operand(1)->shape().dimensions(1), 4000); +} + +TEST_F(SpmdPartitioningTest, NoPartitionSortInTopKWhenSecondOperandIsNotIota) { + const char* const hlo_string = R"( +HloModule module + +%compare-greater-than.8 (p.0.lhs.2566: bf16[], + p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] { + %p.0.lhs.2566 = bf16[] parameter(0) + %convert.164 = f32[] convert(bf16[] 
%p.0.lhs.2566) + %bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164) + %constant.285 = s32[] constant(0) + %compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285), + direction=LT + %constant.286 = u32[] constant(2147483647) + %bitcast-convert.49 = u32[] bitcast-convert(f32[] %convert.164) + %subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49) + %bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84) + %select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50, + s32[] %bitcast-convert.48) + %p.0.rhs.2567 = bf16[] parameter(1) + %convert.165 = f32[] convert(bf16[] %p.0.rhs.2567) + %bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165) + %compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285), + direction=LT + %bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165) + %subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52) + %bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85) + %select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53, + s32[] %bitcast-convert.51) + %compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT + %compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT + %compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645), + direction=EQ + %p.1.lhs.2586 = s32[] parameter(2) + %p.1.rhs.2587 = s32[] parameter(3) + %compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587), + direction=LT + ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647, + pred[] %compare.86) +} + +ENTRY entry { + %arg_tuple.1 = bf16[2,209664] parameter(0) + %arg_tuple.2 = s32[2,209664] parameter(1) + %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1} + %sort.32 = (bf16[2,209664], s32[2,209664]) + sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %arg_tuple.2), + dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8, + metadata={op_type="TopKV2" op_name="TopKV2"} + %get-tuple-element.33 = bf16[2,209664] + get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32), + index=0, metadata={op_type="TopKV2" op_name="TopKV2"} + %slice.34 = bf16[2,2000] slice(bf16[2,209664] + %get-tuple-element.33), slice={[0:2], [0:2000]}, + metadata={op_type="TopKV2" op_name="TopKV2"} + %get-tuple-element.35 = s32[2,209664] + get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32), + index=1, metadata={op_type="TopKV2" op_name="TopKV2"} + %slice.36 = s32[2,2000] slice(s32[2,209664] + %get-tuple-element.35), slice={[0:2], [0:2000]}, + metadata={op_type="TopKV2" op_name="TopKV2"} + ROOT %tuple.46 = (bf16[2,2000], s32[2,2000]) + tuple(bf16[2,2000] %slice.34, s32[2,2000] + %slice.36), sharding={{replicated}, {replicated}}, + metadata={op_name="XLA_Retvals"} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + std::cout << module->ToString(); + auto sort = FindInstruction(module.get(), "sort"); + EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664); + EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664); +} + +TEST_F(SpmdPartitioningTest, NoPartitionSortInTopKWhenNoPartitionInSortDim) { + const char* const hlo_string = R"( +HloModule module + +%compare-greater-than.8 (p.0.lhs.2566: bf16[], + p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] { + %p.0.lhs.2566 = bf16[] parameter(0) + %convert.164 = f32[] convert(bf16[] %p.0.lhs.2566) + 
%bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164) + %constant.285 = s32[] constant(0) + %compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285), + direction=LT + %constant.286 = u32[] constant(2147483647) + %bitcast-convert.49 = u32[] bitcast-convert(f32[] %convert.164) + %subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49) + %bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84) + %select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50, + s32[] %bitcast-convert.48) + %p.0.rhs.2567 = bf16[] parameter(1) + %convert.165 = f32[] convert(bf16[] %p.0.rhs.2567) + %bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165) + %compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285), + direction=LT + %bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165) + %subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52) + %bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85) + %select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53, + s32[] %bitcast-convert.51) + %compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT + %compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT + %compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645), + direction=EQ + %p.1.lhs.2586 = s32[] parameter(2) + %p.1.rhs.2587 = s32[] parameter(3) + %compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587), + direction=LT + ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647, + pred[] %compare.86) +} + +ENTRY entry + (arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) { + %arg_tuple.1 = bf16[2,209664] parameter(0) + %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[2,1]0,1} + %iota.7 = s32[2,209664] iota(), iota_dimension=1, + metadata={op_type="TopKV2" op_name="TopKV2"} + %sort.32 = (bf16[2,209664], s32[2,209664]) + sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7), + dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8, + metadata={op_type="TopKV2" op_name="TopKV2"} + %get-tuple-element.33 = bf16[2,209664] + get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32), + index=0, metadata={op_type="TopKV2" op_name="TopKV2"} + %slice.34 = bf16[2,2000] slice(bf16[2,209664] + %get-tuple-element.33), slice={[0:2], [0:2000]}, + metadata={op_type="TopKV2" op_name="TopKV2"} + %get-tuple-element.35 = s32[2,209664] + get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32), + index=1, metadata={op_type="TopKV2" op_name="TopKV2"} + %slice.36 = s32[2,2000] slice(s32[2,209664] + %get-tuple-element.35), slice={[0:2], [0:2000]}, + metadata={op_type="TopKV2" op_name="TopKV2"} + ROOT %tuple.46 = (bf16[2,2000], s32[2,2000]) + tuple(bf16[2,2000] %slice.34, s32[2,2000] + %slice.36), sharding={{replicated}, {replicated}}, + metadata={op_name="XLA_Retvals"} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + std::cout << module->ToString(); + auto sort = FindInstruction(module.get(), "sort"); + EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664); + EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664); +} + +TEST_F(SpmdPartitioningTest, NoPartitionSortInTopKWhenSliceInOtherDim) { + const char* const hlo_string = R"( +HloModule module + +%compare-greater-than.8 (p.0.lhs.2566: bf16[], + p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] { + %p.0.lhs.2566 = 
bf16[] parameter(0) + %convert.164 = f32[] convert(bf16[] %p.0.lhs.2566) + %bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164) + %constant.285 = s32[] constant(0) + %compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285), + direction=LT + %constant.286 = u32[] constant(2147483647) + %bitcast-convert.49 = u32[] bitcast-convert(f32[] %convert.164) + %subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49) + %bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84) + %select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50, + s32[] %bitcast-convert.48) + %p.0.rhs.2567 = bf16[] parameter(1) + %convert.165 = f32[] convert(bf16[] %p.0.rhs.2567) + %bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165) + %compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285), + direction=LT + %bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165) + %subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52) + %bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85) + %select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53, + s32[] %bitcast-convert.51) + %compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT + %compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT + %compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645), + direction=EQ + %p.1.lhs.2586 = s32[] parameter(2) + %p.1.rhs.2587 = s32[] parameter(3) + %compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587), + direction=LT + ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647, + pred[] %compare.86) +} + +ENTRY entry { + %arg_tuple.1 = bf16[2,209664] parameter(0) + %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1} + %iota.7 = s32[2,209664] iota(), iota_dimension=1, + metadata={op_type="TopKV2" op_name="TopKV2"} + %sort.32 = (bf16[2,209664], s32[2,209664]) + sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7), + dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8, + metadata={op_type="TopKV2" op_name="TopKV2"} + %get-tuple-element.33 = bf16[2,209664] + get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32), + index=0, metadata={op_type="TopKV2" op_name="TopKV2"} + %slice.34 = bf16[1,209664] slice(bf16[2,209664] + %get-tuple-element.33), slice={[0:1], [0:209664]}, + metadata={op_type="TopKV2" op_name="TopKV2"} + %get-tuple-element.35 = s32[2,209664] + get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32), + index=1, metadata={op_type="TopKV2" op_name="TopKV2"} + %slice.36 = s32[1,209664] slice(s32[2,209664] + %get-tuple-element.35), slice={[0:1], [0:209664]}, + metadata={op_type="TopKV2" op_name="TopKV2"} + ROOT %tuple.46 = (bf16[1,209664], s32[1,209664]) + tuple(bf16[1,209664] %slice.34, s32[1,209664] + %slice.36), sharding={{replicated}, {replicated}}, + metadata={op_name="XLA_Retvals"} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto sort = FindInstruction(module.get(), "sort"); + EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664); + EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664); +} + +TEST_F(SpmdPartitioningTest, ShardableTranspose) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[16,38,38,4] parameter(0) + %param0.copy = f32[16,38,38,4] copy(%param0), 
sharding={devices=[1,2,1,1]0,1}
+  ROOT %transpose = f32[16,4,38,38] transpose(%param0.copy),
+    dimensions={0,3,1,2}, sharding={devices=[1,1,2,1]0,1}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/2));
+  VLOG(1) << module->ToString();
+
+  auto root = module->entry_computation()->root_instruction();
+  auto param0 = AllOf(
+      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
+                                op::Constant(), op::Constant())),
+      op::Shape("f32[16,19,38,4]"));
+  EXPECT_THAT(root, AllOf(op::Transpose(param0), op::Shape("f32[16,4,19,38]")));
+}
+
+TEST_F(SpmdPartitioningTest, MultiDimensionShardedTranspose) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %param0 = f32[16,38,38,4] parameter(0)
+  %param0.copy = f32[16,38,38,4] copy(%param0),
+    sharding={devices=[4,2,1,1]0,1,2,3,4,5,6,7}
+  ROOT %transpose = f32[38,4,16,38] transpose(%param0.copy),
+    dimensions={1,3,0,2}, sharding={devices=[2,1,4,1]0,2,4,6,1,3,5,7}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/8));
+  VLOG(1) << module->ToString();
+
+  auto root = module->entry_computation()->root_instruction();
+  auto param0 = AllOf(
+      op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Reshape(),
+                                op::Constant(), op::Constant())),
+      op::Shape("f32[4,19,38,4]"));
+  EXPECT_THAT(root, AllOf(op::Transpose(param0), op::Shape("f32[19,4,4,38]")));
+}
+
+TEST_F(SpmdPartitioningTest, NonShardableTranspose) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %param0 = f32[16,38,38,4] parameter(0)
+  %param0.copy = f32[16,38,38,4] copy(%param0), sharding={devices=[1,2,1,1]0,1}
+  ROOT %transpose = f32[16,4,38,38] transpose(%param0.copy),
+    dimensions={0,3,1,2}, sharding={devices=[1,2,1,1]0,1}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/2));
+  VLOG(1) << module->ToString();
+
+  auto root = module->entry_computation()->root_instruction();
+  auto reshard = AllOf(op::Reshape(op::Transpose(op::Reshape(op::AllToAll()))),
+                       op::Shape("f32[16,38,38,2]"));
+  EXPECT_THAT(root, AllOf(op::Transpose(), op::Shape("f32[16,2,38,38]")));
+}
+
+TEST_F(SpmdPartitioningTest, ShardableReshape) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %param0 = f32[38,38,324] parameter(0)
+  %param0.copy = f32[38,38,324] copy(%param0), sharding={devices=[2,1,1]0,1}
+  ROOT %reshape = f32[38,38,4,81] reshape(%param0.copy),
+    sharding={devices=[2,1,1,1]0,1}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/2));
+  VLOG(1) << module->ToString();
+
+  auto root = module->entry_computation()->root_instruction();
+  auto param0 =
+      AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(),
+                                      op::Constant(), op::Constant())),
+            op::Shape("f32[19,38,324]"));
+  EXPECT_THAT(root, AllOf(op::Reshape(param0), op::Shape("f32[19,38,4,81]")));
+}
+
+TEST_F(SpmdPartitioningTest, NonShardableReshape) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %param0 = f32[38,38,324] parameter(0)
+  %param0.copy = f32[38,38,324] copy(%param0), sharding={devices=[1,1,2]0,1}
+  ROOT %transpose = f32[38,38,4,81] reshape(%param0.copy),
+    sharding={devices=[1,1,1,2]0,1}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/2));
+  VLOG(1) << module->ToString();
+
+  auto root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(
+      root,
+      
AllOf(op::DynamicSlice( + AllOf(op::Pad( + AllOf(op::Reshape(AllOf(op::AllReduce(), + op::Shape("f32[38,38,324]"))), + op::Shape("f32[38,38,4,81]")), + op::Constant()), + op::Shape("f32[38,38,4,82]")), + op::Constant(), op::Constant(), op::Constant(), op::Reshape()), + op::Shape("f32[38,38,4,41]"))); +} + +TEST_F(SpmdPartitioningTest, ReshapeMergeDimsWithHaloExchange) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %input = s32[2,3,7,10] parameter(0), sharding={devices=[1,1,2,1]0,1} + ROOT %reshape = s32[3,2,1,14,5] reshape(%input), + sharding={devices=[1,1,1,2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto reshape = + AllOf(op::Reshape(op::Parameter(0)), op::Shape("s32[3,2,1,8,5]")); + auto halo = op::CollectivePermute(op::Slice(reshape)); + auto exchanged = + op::DynamicSlice(op::Concatenate(halo, reshape), _, _, _, _, _); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(exchanged, op::Shape("s32[3,2,1,7,5]"))); +} + +// Produces an invalid module after transformation. +TEST_F(SpmdPartitioningTest, InceptionV3_4_way_ReduceWindowDilated) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + %param0 = f32[128,5,5,768] parameter(0) + %param0.copy = f32[128,5,5,768] copy(%param0), + sharding={devices=[1,4,1,1]0,1,2,3} + %constant.1 = f32[] constant(0), sharding={replicated} + ROOT %rw = f32[128,17,17,768] reduce-window(%param0.copy, %constant.1), + window={size=1x5x5x1 pad=0_0x4_4x4_4x0_0 lhs_dilate=1x3x3x1}, + to_apply=sum, sharding={devices=[1,4,1,1]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto input_shard = op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(0), op::Constant()), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())); + auto id_mul4_add1 = + op::Add(op::Multiply(op::Reshape(), op::Constant()), op::Constant()); + auto id_mul5 = op::Multiply(op::Reshape(), op::Constant()); + auto id_mul5_add1_div3 = + op::Divide(op::Add(id_mul5, op::Constant()), op::Constant()); + auto before_masking = AllOf( + op::Shape("f32[128,3,5,768]"), + op::DynamicSlice( + AllOf( + op::Shape("f32[128,4,5,768]"), + op::Concatenate(op::CollectivePermute(input_shard), input_shard)), + op::Constant(), + op::Subtract(op::Constant(), + op::Subtract(id_mul4_add1, id_mul5_add1_div3)), + op::Constant(), op::Constant())); + auto masked = op::Select( + op::And(op::Compare(op::Add(op::Iota(), op::Broadcast(id_mul5_add1_div3)), + op::Broadcast(op::Constant())), + op::Compare(op::Add(op::Iota(), op::Broadcast(id_mul5_add1_div3)), + op::Broadcast(op::Constant()))), + before_masking, op::Broadcast(op::Constant())); + auto rw = AllOf(op::Shape("f32[128,7,17,768]"), + op::ReduceWindow(masked, op::Constant())); + auto final_slice_index = op::Subtract( + id_mul5, + op::Add(op::Multiply(id_mul5_add1_div3, op::Constant()), op::Constant())); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + AllOf(op::Shape("f32[128,5,17,768]"), + op::DynamicSlice(rw, op::Constant(), final_slice_index, + op::Constant(), op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, TiledToTiledReduce) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + 
ROOT add = f32[] add(a, b) +} + +ENTRY entry { + %param0 = f32[4,32,32,128] parameter(0) + %param0.copy = f32[4,32,32,128] copy(%param0), + sharding={devices=[1,1,1,2]0,1} + %constant.1 = f32[] constant(0), sharding={replicated} + %reduce = f32[128] reduce(%param0.copy, %constant.1), dimensions={0,1,2}, + to_apply=%sum, sharding={devices=[2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[4,32,32,64]")); + + EXPECT_THAT(root, + AllOf(op::Reduce(param0, op::Constant()), op::Shape("f32[64]"))); +} + +TEST_F(SpmdPartitioningTest, TiledToTiledTupleReduce) { + const char* const hlo_string = R"( +HloModule module + +%minmax_func { + %lhs_value = f32[] parameter(0) + %rhs_value = f32[] parameter(2) + %compare.2 = pred[] compare(%lhs_value, %rhs_value), direction=GT + %select.4 = f32[] select(%compare.2, %lhs_value, %rhs_value) + %lhs_index = s32[] parameter(1) + %rhs_index = s32[] parameter(3) + %select.5 = s32[] select(%compare.2, %lhs_index, %rhs_index) + ROOT %tuple.2 = (f32[], s32[]) tuple(%select.4, %select.5) +} + +ENTRY %main { + %param0 = f32[28,10] parameter(0), sharding={devices=[2,1]0,1} + %param1 = s32[28,10] parameter(1), sharding={devices=[2,1]0,1} + %init0 = f32[] parameter(2) + %init1 = s32[] parameter(3) + ROOT %reduce = (f32[28], s32[28]) reduce(%param0, %param1, %init0, %init1), + dimensions={1}, to_apply=%minmax_func, + sharding={{devices=[2]0,1}, {devices=[2]0,1}} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Reduce(op::Parameter(0), op::Parameter(1), + op::Parameter(2), op::Parameter(3)), + op::Shape("(f32[14], s32[14])"))); +} + +TEST_F(SpmdPartitioningTest, TiledToTiledReduceOutputReshard) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + %param0 = f32[4,32,32,128] parameter(0) + %param0.copy = f32[4,32,32,128] copy(%param0), + sharding={devices=[1,2,1,1]0,1} + %constant.1 = f32[] constant(0), sharding={replicated} + %reduce = f32[128] reduce(%param0.copy, %constant.1), dimensions={0,1,2}, + to_apply=%sum, sharding={devices=[2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[4,16,32,128]")); + + EXPECT_THAT(root, + AllOf(op::DynamicSlice( + AllOf(op::AllReduce(op::Reduce(param0, op::Constant())), + op::Shape("f32[128]")), + op::Reshape()), + op::Shape("f32[64]"))); +} + +TEST_F(SpmdPartitioningTest, IotaAlongNonTileDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + ROOT %iota = s32[16,80,91] iota(), iota_dimension=1, + sharding={devices=[1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = 
module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Iota(), op::Shape("s32[16,80,46]"))); +} + +TEST_F(SpmdPartitioningTest, IotaAlongTileDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + ROOT %iota = s32[16,80,91] iota(), iota_dimension=2, + sharding={devices=[1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Add(op::Iota(), op::Broadcast()), + op::Shape("s32[16,80,46]"))); +} + +TEST_F(SpmdPartitioningTest, U32IotaAlongTileDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + ROOT %iota = u32[16,80,91] iota(), iota_dimension=2, + sharding={devices=[1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Add(op::Iota(), op::Broadcast()), + op::Shape("u32[16,80,46]"))); +} + +TEST_F(SpmdPartitioningTest, Conditional) { + const char* const hlo_string = R"( +HloModule module + +Negate { + x = f32[4,5] parameter(0), sharding={replicated} + ROOT negate = f32[4,5] negate(x), sharding={replicated} +} + +Identity { + y = f32[4,5] parameter(0), sharding={devices=[2,1]0,1} + ROOT copy = f32[4,5] copy(y), sharding={devices=[2,1]0,1} +} + +ENTRY entry { + %param.0 = pred[] parameter(0) + %param.0.copy = pred[] copy(%param.0), sharding={maximal device=0} + %param.1 = f32[4,5] parameter(1) + %param.1.copy = f32[4,5] copy(%param.1), sharding={replicated} + %param.2 = f32[4,5] parameter(2) + %param.2.copy = f32[4,5] copy(%param.2), sharding={devices=[2,1]0,1} + ROOT cond = f32[4,5] conditional(%param.0.copy, %param.1.copy, %param.2.copy), + true_computation=Negate, false_computation=Identity, + sharding={devices=[2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto param0 = AllOf(op::Copy(op::Copy(op::Parameter()), op::Shape("pred[]"))); + auto param1 = AllOf(op::Copy(op::Parameter()), op::Shape("f32[4,5]")); + auto param2 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), + op::Constant())), + op::Shape("f32[2,5]")); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Conditional(op::AllReduce(), param1, param2), + op::Shape("f32[2,5]"))); + + auto then_branch_root = root->branch_computation(0)->root_instruction(); + EXPECT_THAT(then_branch_root, + AllOf(op::DynamicSlice(op::Negate(op::Parameter()), op::Reshape(), + op::Constant()), + op::Shape("f32[2,5]"))); + + auto else_branch_root = root->branch_computation(1)->root_instruction(); + EXPECT_THAT(else_branch_root, + AllOf(op::Copy(op::Parameter()), op::Shape("f32[2,5]"))); +} + +TEST_F(SpmdPartitioningTest, SelectAndScatter_RetinaNet) { + const char* const hlo_string = R"( +HloModule module + +ge { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT compare = pred[] compare(a, b), direction=GE +} + +sum { + c = f32[] parameter(0) + d = f32[] parameter(1) + ROOT add = f32[] add(c, d) +} + +ENTRY entry { + %param.0 = f32[32,128,384,64] parameter(0) + %param.0.copy = f32[32,128,384,64] copy(%param.0), + sharding={devices=[1,8,1,1]0,1,2,3,4,5,6,7} + %param.1 = f32[32,64,192,64] parameter(1) + %param.1.copy = f32[32,64,192,64] 
copy(%param.1), + sharding={devices=[1,8,1,1]0,1,2,3,4,5,6,7} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT select-and-scatter = f32[32,128,384,64] select-and-scatter(param.0.copy, + %param.1.copy, constant.1), window={size=1x1x1x1 stride=1x2x2x1}, + select=ge, scatter=sum, sharding={devices=[1,8,1,1]0,1,2,3,4,5,6,7} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto source = AllOf( + op::Shape("f32[32,8,192,64]"), + op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(), op::Reshape(), + op::Constant(), op::Constant()))); + auto data = AllOf( + op::Shape("f32[32,16,384,64]"), + op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(), + op::Constant(), op::Constant()))); + + EXPECT_THAT(root, op::SelectAndScatter(data, source, op::Constant())); + EXPECT_EQ(root->window().dimensions(0).padding_low(), 0); + EXPECT_EQ(root->window().dimensions(0).padding_high(), 0); +} + +TEST_F(SpmdPartitioningTest, TiledDot) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,64] parameter(0) + %lhs.copy = f32[128,64] copy(%lhs), sharding={devices=[1,2]0,1} + %rhs = f32[64,256] parameter(1) + %rhs.copy = f32[64,256] copy(%rhs), sharding={devices=[2,1]0,1} + ROOT %conv = f32[128,256] convolution(%lhs.copy, %rhs.copy), + dim_labels=bf_io->bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + PartitionComputation(hlo_string, /*num_devices=*/2, + /*conv_halo_exchange_always_on_lhs=*/false)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), + op::Reshape())), + op::Shape("f32[128,32]")); + auto rhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), + op::Constant())), + op::Shape("f32[32,256]")); + EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(lhs, rhs)), + op::Shape("f32[128,256]"))); +} + +TEST_F(SpmdPartitioningTest, TiledDotOutputTiled) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,64] parameter(0) + %lhs.copy = f32[128,64] copy(%lhs), sharding={devices=[1,2]0,1} + %rhs = f32[64,256] parameter(1) + %rhs.copy = f32[64,256] copy(%rhs), sharding={devices=[2,1]0,1} + ROOT %conv = f32[128,256] convolution(%lhs.copy, %rhs.copy), + dim_labels=bf_io->bf, sharding={devices=[1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), + op::Reshape())), + op::Shape("f32[128,32]")); + auto rhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), + op::Constant())), + op::Shape("f32[32,256]")); + EXPECT_THAT(root, AllOf(op::DynamicSlice( + AllOf(op::AllReduce(op::Convolution(lhs, rhs)), + op::Shape("f32[128,256]")), + op::Constant(), op::Reshape()), + op::Shape("f32[128,128]"))); +} + +TEST_F(SpmdPartitioningTest, BatchPartitionedConvolution) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,256,256] parameter(0) + %lhs.copy = f32[128,256,256] copy(%lhs), sharding={devices=[1,2,1]0,1} + %rhs = f32[256,8,1] parameter(1) + %rhs.copy = f32[256,8,1] copy(%rhs), sharding={replicated} + ROOT %conv = 
f32[128,256,8] convolution(%lhs.copy, %rhs.copy), + window={size=1}, dim_labels=0bf_io0->0bf, sharding={devices=[1,2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), + op::Reshape(), op::Constant())), + op::Shape("f32[128,128,256]")); + auto rhs = AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[256,8,1]")); + EXPECT_THAT(root, + AllOf(op::Convolution(lhs, rhs), op::Shape("f32[128,128,8]"))); +} + +TEST_F(SpmdPartitioningTest, DotOutputFeaturePartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[24,64] parameter(0) + %lhs.copy = f32[24,64] copy(%lhs), sharding={replicated} + %rhs = f32[39296,64] parameter(1) + %rhs.copy = f32[39296,64] copy(%rhs), sharding={devices=[2,1]0,1} + ROOT %dot = f32[24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={}, rhs_batch_dims={}, + lhs_contracting_dims={1}, rhs_contracting_dims={1}, + sharding={devices=[1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[24,64]")); + auto rhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(), + op::Constant())), + op::Shape("f32[19648,64]")); + EXPECT_THAT(root, AllOf(op::Dot(lhs, rhs), op::Shape("f32[24,19648]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumBatchPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64] parameter(0) + %lhs.copy = f32[32,24,64] copy(%lhs), sharding={devices=[2,1,1]0,1} + %rhs = f32[32,39296,64] parameter(1) + %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={devices=[2,1,1]0,1} + ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[16,24,64]")); + auto rhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[16,39296,64]")); + EXPECT_THAT(root, AllOf(op::Dot(lhs, rhs), op::Shape("f32[16,24,39296]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumLHSandOutputBatchPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64] parameter(0) + %lhs.copy = f32[32,24,64] copy(%lhs), sharding={devices=[2,1,1]0,1} + %rhs = f32[32,39296,64] parameter(1) + %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={replicated} + ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), 
+ op::Constant(), op::Constant())), + op::Shape("f32[16,24,64]")); + auto rhs = AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64]")); + EXPECT_THAT(root, AllOf(op::Dot(lhs, op::DynamicSlice(rhs, op::Reshape(), + op::Constant(), + op::Constant())), + op::Shape("f32[16,24,39296]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumRHSandOutputBatchPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64] parameter(0) + %lhs.copy = f32[32,24,64] copy(%lhs), sharding={devices=[1,2,1]0,1} + %rhs = f32[32,39296,64] parameter(1) + %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={devices=[2,1,1]0,1} + ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), + op::Reshape(), op::Constant())), + op::Shape("f32[32,12,64]")); + auto rhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[16,39296,64]")); + auto lhs_reshard = op::Reshape(op::Transpose(op::AllToAll(op::Reshape(lhs)))); + EXPECT_THAT(root, + AllOf(op::Dot(lhs_reshard, rhs), op::Shape("f32[16,24,39296]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumOutputBatchPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64] parameter(0) + %lhs.copy = f32[32,24,64] copy(%lhs), sharding={replicated} + %rhs = f32[32,39296,64] parameter(1) + %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={replicated} + ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs_slice = + AllOf(op::DynamicSlice(op::Copy(op::Parameter(0)), op::Reshape(), + op::Constant(), op::Constant()), + op::Shape("f32[16,24,64]")); + auto rhs_slice = + AllOf(op::DynamicSlice(op::Copy(op::Parameter(1)), op::Reshape(), + op::Constant(), op::Constant()), + op::Shape("f32[16,39296,64]")); + EXPECT_THAT(root, AllOf(op::Dot(lhs_slice, rhs_slice), + op::Shape("f32[16,24,39296]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumContractingDimsPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,1,2,2]0,1,2,3} + %rhs = f32[32,39296,64,128] parameter(1) + %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={devices=[1,1,2,2]0,1,2,3} + ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), + op::Constant(), op::Reshape(), op::Reshape())), + 
op::Shape("f32[32,24,32,64]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(), + op::Constant(), op::Reshape(), op::Reshape())), + op::Shape("f32[32,39296,32,64]")); + EXPECT_THAT(root, AllOf(op::AllReduce(op::Dot(lhs, rhs)), + op::Shape("f32[32,24,39296]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumLHSNonContractingDimsPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,2]0,1,2,3} + %rhs = f32[32,39296,64] parameter(1) + %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={replicated} + ROOT %dot = f32[32,24,128,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[1,2,2,1]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(), + op::Constant(), op::Reshape())), + op::Shape("f32[32,12,64,64]")); + auto rhs = AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64]")); + EXPECT_THAT(root, AllOf(op::Dot(lhs, rhs), op::Shape("f32[32,12,64,39296]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumRHSNonContractingDimsPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64] parameter(0) + %lhs.copy = f32[32,24,64] copy(%lhs), sharding={replicated} + %rhs = f32[32,39296,64,128] parameter(1) + %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={devices=[1,2,1,2]0,1,2,3} + ROOT %dot = f32[32,24,39296,128] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[1,1,2,2]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[32,24,64]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(), op::Reshape(), + op::Constant(), op::Reshape())), + op::Shape("f32[32,19648,64,64]")); + EXPECT_THAT(root, AllOf(op::Dot(lhs, rhs), op::Shape("f32[32,24,19648,64]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumOutputLHSNonContractingDimPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={replicated} + %rhs = f32[32,39296,64,128] parameter(1) + %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={replicated} + ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={devices=[1,2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[32,24,64,128]")); + auto rhs = + AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64,128]")); + EXPECT_THAT( + root, + AllOf(op::Dot(AllOf(op::DynamicSlice(lhs, op::Constant(), op::Reshape(), + op::Constant(), op::Constant()), + 
op::Shape("f32[32,12,64,128]")), + rhs), + op::Shape("f32[32,12,39296]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumOutputRHSNonContractingDimPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={replicated} + %rhs = f32[32,39296,64,128] parameter(1) + %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={replicated} + ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={devices=[1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[32,24,64,128]")); + auto rhs = + AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64,128]")); + EXPECT_THAT(root, + AllOf(op::Dot(lhs, AllOf(op::DynamicSlice( + rhs, op::Constant(), op::Reshape(), + op::Constant(), op::Constant()), + op::Shape("f32[32,19648,64,128]"))), + op::Shape("f32[32,24,19648]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumRHSWindowedNonContracting) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[32,39295,64,128] parameter(1) + %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={devices=[1,2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, + /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[32,12,64,128]")); + auto rhs = + AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(1), op::Constant()), + op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[32,19648,64,128]")); + EXPECT_THAT( + root, + AllOf(op::Slice(AllOf(op::GetTupleElement(op::While(op::Tuple( + lhs, rhs, op::Broadcast(), op::Constant()))), + op::Shape("f32[32,12,39296]"))), + op::Shape("f32[32,12,39295]"))); + auto while_loop = root->operand(0)->operand(0); + // Check loop condition. + EXPECT_THAT( + while_loop->while_condition()->root_instruction(), + op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant())); + + // Check loop body. + auto next_i = op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()); + auto window = op::Conditional(op::Compare(next_i, op::Constant()), + op::GetTupleElement(op::Parameter(0)), + op::GetTupleElement(op::Parameter(0))); + auto partial_output = op::Dot(op::GetTupleElement(op::Parameter(0)), + op::GetTupleElement(op::Parameter(0))); + EXPECT_THAT( + while_loop->while_body()->root_instruction(), + op::Tuple(op::GetTupleElement(op::Parameter(0)), window, + op::DynamicUpdateSlice(op::GetTupleElement(op::Parameter(0)), + partial_output, op::Constant(), + op::Constant(), op::Reshape()), + next_i)); + + // Check the conditional that contains the collective permute. 
+ auto cp_conditional = + while_loop->while_body()->root_instruction()->operand(1); + EXPECT_THAT(cp_conditional->true_computation()->root_instruction(), + op::CollectivePermute(op::Parameter(0))); + EXPECT_THAT(cp_conditional->false_computation()->root_instruction(), + op::Parameter(0)); +} + +TEST_F(SpmdPartitioningTest, EinsumRHSWindowedContracting) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,63,128] parameter(0) + %lhs.copy = f32[32,24,63,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[32,39296,63,128] parameter(1) + %rhs.copy = f32[32,39296,63,128] copy(%rhs), sharding={devices=[1,1,2,1]0,1} + ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={devices=[1,2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, + /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[32,12,63,128]")); + auto rhs = + AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(1), op::Constant()), + op::Constant(), op::Constant(), + op::Reshape(), op::Constant())), + op::Shape("f32[32,39296,32,128]")); + auto masked_rhs = + op::Select(op::Compare(), rhs, op::Broadcast(op::Constant())); + EXPECT_THAT(root, + AllOf(op::GetTupleElement(op::While(op::Tuple( + lhs, masked_rhs, op::Broadcast(), op::Constant()))), + op::Shape("f32[32,12,39296]"))); + auto while_loop = root->operand(0); + // Check loop condition. + EXPECT_THAT( + while_loop->while_condition()->root_instruction(), + op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant())); + + // Check loop body. + auto next_i = op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()); + auto window = op::Conditional(op::Compare(next_i, op::Constant()), + op::GetTupleElement(op::Parameter(0)), + op::GetTupleElement(op::Parameter(0))); + auto partial_output = op::Dot( + op::DynamicSlice( + op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()), + op::Constant(), op::Constant(), op::Reshape(), op::Constant()), + op::GetTupleElement(op::Parameter(0))); + EXPECT_THAT( + while_loop->while_body()->root_instruction(), + op::Tuple(op::GetTupleElement(op::Parameter(0)), window, + op::Add(op::GetTupleElement(op::Parameter(0)), partial_output), + next_i)); + + // Check the conditional that contains the collective permute. 
+ auto cp_conditional = + while_loop->while_body()->root_instruction()->operand(1); + EXPECT_THAT(cp_conditional->true_computation()->root_instruction(), + op::CollectivePermute(op::Parameter(0))); + EXPECT_THAT(cp_conditional->false_computation()->root_instruction(), + op::Parameter(0)); +} + +TEST_F(SpmdPartitioningTest, EinsumRHSWindowedNonContractingReduce1) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + %lhs = f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[32,39295,64,128] parameter(1) + %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={devices=[1,2,1]0,1} + %constant = f32[] constant(0) + %constant.1 = f32[] constant(2) + %broadcast = f32[32,24,39295] broadcast(%constant.1), dimensions={}, + sharding={devices=[1,2,1]0,1} + %multiply = f32[32,24,39295] multiply(%dot, %broadcast), + sharding={devices=[1,2,1]0,1} + ROOT %reduce = f32[32,24] reduce(%multiply, %constant), dimensions={2}, + to_apply=sum, sharding={devices=[1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, + /*num_devices=*/2)); + VLOG(1) << module->ToString(); + // Involves loop code motion, skips pattern matching. +} + +TEST_F(SpmdPartitioningTest, EinsumRHSWindowedNonContractingReduce2) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + %lhs = f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[32,39295,64,128] parameter(1) + %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={devices=[1,2,1]0,1} + %constant = f32[] constant(0) + %constant.1 = f32[] constant(2) + %broadcast = f32[32,24,39295] broadcast(%constant.1), dimensions={}, + sharding={devices=[1,2,1]0,1} + %multiply = f32[32,24,39295] multiply(%dot, %broadcast), + sharding={devices=[1,2,1]0,1} + ROOT %reduce = f32[32,39295] reduce(%multiply, %constant), dimensions={1}, + to_apply=sum, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, + /*num_devices=*/2)); + VLOG(1) << module->ToString(); + // Involves loop code motion, skips pattern matching. 
+} + +TEST_F(SpmdPartitioningTest, EinsumRHSWindowedContractingFromBroadcast) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %rhs = f32[32,39296,63,128] parameter(0) + %rhs.copy = f32[32,39296,63,128] copy(%rhs), sharding={devices=[1,1,2,1]0,1} + %constant.1 = f32[] constant(2) + %broadcast = f32[32,24,63,128] broadcast(%constant.1), dimensions={}, + sharding={devices=[1,2,1,1]0,1} + %add = f32[32,24,63,128] add(%broadcast, %broadcast), + sharding={devices=[1,2,1,1]0,1} + ROOT %dot = f32[32,24,39296] dot(%add, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={devices=[1,2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, + /*num_devices=*/2)); + VLOG(1) << module->ToString(); + // Involves loop code motion, skips pattern matching. +} + +TEST_F(SpmdPartitioningTest, ReplicatedRng) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = s32[] parameter(0) + %lhs.copy = s32[] copy(%lhs), sharding={replicated} + %rhs = s32[] parameter(1) + %rhs.copy = s32[] copy(%rhs), sharding={replicated} + ROOT %rng = s32[4]{0} rng(%lhs.copy, %rhs.copy), + distribution=rng_uniform, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("s32[]")); + auto rhs = AllOf(op::Copy(op::Parameter(1)), op::Shape("s32[]")); + EXPECT_THAT( + root, + AllOf(op::AllReduce(op::Select( + op::Broadcast(op::Compare(op::PartitionId(), op::Constant())), + op::Rng(), op::Broadcast(op::Constant()))), + op::Shape("s32[4]"))); +} + +TEST_F(SpmdPartitioningTest, PartitionedRng) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = s32[] parameter(0) + %lhs.copy = s32[] copy(%lhs), sharding={replicated} + %rhs = s32[] parameter(1) + %rhs.copy = s32[] copy(%rhs), sharding={maximal device=1} + ROOT %rng = s32[4]{0} rng(%lhs.copy, %rhs.copy), + distribution=rng_uniform, sharding={devices=[2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("s32[]")); + auto rhs = AllOf(op::Copy(op::Copy(op::Parameter(1))), op::Shape("s32[]")); + EXPECT_THAT(root, AllOf(op::Rng(lhs, op::AllReduce(op::Select( + op::Broadcast(op::Compare()), rhs, + op::Broadcast(op::Constant())))), + op::Shape("s32[2]"))); +} + +TEST_F(SpmdPartitioningTest, DynamicSliceAlongNonPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %input = s32[128,64] parameter(0) + %input.copy = s32[128,64] copy(%input), sharding={devices=[2,1]0,1} + %index = s32[] parameter(1) + %constant = s32[] constant(0) + ROOT %dynamic-slice = s32[128,2] dynamic-slice(%input.copy, %constant, %index), + dynamic_slice_sizes={128,2}, sharding={devices=[2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto input = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), + op::Constant())), + op::Shape("s32[64,64]")); + EXPECT_THAT(root, + AllOf(op::DynamicSlice(input, 
op::Constant(), op::Parameter(1)), + op::Shape("s32[64,2]"))); +} + +TEST_F(SpmdPartitioningTest, DynamicUpdateSliceAlongNonPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %input = s32[128,64] parameter(0) + %input.copy = s32[128,64] copy(%input), sharding={devices=[2,1]0,1} + %index = s32[] parameter(1) + %constant = s32[] constant(0) + %update = s32[128,2] parameter(2) + %update.copy = s32[128,2] copy(%update), sharding={devices=[2,1]0,1} + ROOT %dynamic-update-slice = s32[128,64] + dynamic-update-slice(%input.copy, %update.copy, %constant, %index), + sharding={devices=[2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto input = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), + op::Constant())), + op::Shape("s32[64,64]")); + auto update = AllOf(op::Copy(op::DynamicSlice(op::Parameter(2), op::Reshape(), + op::Constant())), + op::Shape("s32[64,2]")); + EXPECT_THAT(root, AllOf(op::DynamicUpdateSlice(input, update, op::Constant(), + op::Parameter(1)), + op::Shape("s32[64,64]"))); +} + +TEST_F(SpmdPartitioningTest, PassthroughGather) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[2,9] parameter(0), sharding={devices=[1,2]0,1} + %indices = s32[3] parameter(1), sharding={replicated} + ROOT %gather = f32[3,9] gather(%input, %indices), offset_dims={1}, + collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, + slice_sizes={1,9}, sharding={devices=[1,2]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Gather(op::Parameter(0), op::Parameter(1)), + op::Shape("f32[3,5]"))); +} + +TEST_F(SpmdPartitioningTest, GatherPartitionedOnTrivialSliceDims) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[17,9] parameter(0), sharding={devices=[2,1]0,1} + %indices = s32[2,3] parameter(1), sharding={replicated} + ROOT %gather = f32[2,3,9] gather(%input, %indices), offset_dims={2}, + collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, + slice_sizes={1,9}, sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto offset = op::Reshape( + op::DynamicSlice(op::Constant(), op::PartitionId(), op::Constant())); + auto min = AllOf(op::Broadcast(offset), op::Shape("s32[2,3]")); + auto max = AllOf(op::Broadcast(op::Add(offset, op::Constant())), + op::Shape("s32[2,3]")); + auto clamp = op::Clamp(min, op::Parameter(1), max); + auto gather = op::Gather(op::Parameter(0), op::Subtract(clamp, min)); + auto mask = + op::Or(op::Lt(op::Parameter(1), min), op::Gt(op::Parameter(1), max)); + auto masked = + op::Select(op::Broadcast(mask), op::Broadcast(op::Constant()), gather); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::AllReduce(masked), op::Shape("f32[2,3,9]"))); +} + +TEST_F(SpmdPartitioningTest, PassthroughScatter) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9] 
parameter(0), sharding={devices=[1,2]0,1} + %indices = s32[3] parameter(1), sharding={replicated} + %updates = f32[3,9] parameter(2), sharding={devices=[1,2]0,1} + ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, sharding={devices=[1,2]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Scatter(op::Parameter(0), op::Parameter(1), + op::Parameter(2)), + op::Shape("f32[2,5]"))); +} + +TEST_F(SpmdPartitioningTest, ScatterPartitionedOnTrivialSliceDims) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[17,9] parameter(0), sharding={devices=[2,1]0,1} + %indices = s32[2,3] parameter(1), sharding={replicated} + %updates = f32[2,3,9] parameter(2), sharding={replicated} + ROOT %scatter = f32[17,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={2}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=2, sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto offset = op::Reshape( + op::DynamicSlice(op::Constant(), op::PartitionId(), op::Constant())); + auto indices = op::Subtract( + op::Parameter(1), AllOf(op::Broadcast(offset), op::Shape("s32[2,3]"))); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + AllOf(op::Scatter(op::Parameter(0), indices, op::Parameter(2)), + op::Shape("f32[9,9]"))); +} + +TEST_F(SpmdPartitioningTest, TiledReversePassthrough) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + constant = f32[3,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1}}), + sharding={devices=[2,1]0,1} + ROOT reverse = f32[3,3]{1,0} reverse(constant), dimensions={1}, + sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[2,3]{1,0}"), + op::Reverse(op::DynamicSlice( + op::Pad(op::Constant(), op::Constant()), + op::Reshape(), op::Constant())))); +} + +TEST_F(SpmdPartitioningTest, TiledReversePassthroughViaReversedSharding) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + param = f32[4] parameter(0), sharding={devices=[2]0,1} + ROOT reverse = f32[4] reverse(param), dimensions={0}, + sharding={devices=[2]1,0} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[2]"), op::Reverse(op::Parameter(0)))); +} + +TEST_F(SpmdPartitioningTest, TiledReverseSwapShards) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + param = f32[4] parameter(0), sharding={devices=[2]0,1} + ROOT reverse = f32[4] reverse(param), dimensions={0}, + sharding={devices=[2]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, 
/*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + AllOf(op::Shape("f32[2]"), + op::Reverse(op::CollectivePermute(op::Parameter(0))))); +} + +TEST_F(SpmdPartitioningTest, TiledReverseHaloExchange) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + param = f32[3] parameter(0), sharding={devices=[2]0,1} + ROOT reverse = f32[3] reverse(param), dimensions={0}, + sharding={devices=[2]1,0} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + auto halo_exchange_concat = + op::Concatenate(AllOf(op::Shape("f32[1]"), + op::CollectivePermute(op::Slice(op::Parameter(0)))), + op::Parameter(0)); + auto after_halo_exchange = op::Slice(halo_exchange_concat); + EXPECT_THAT(root, + AllOf(op::Shape("f32[2]"), op::Reverse(after_halo_exchange))); +} + +TEST_F(SpmdPartitioningTest, MixWithManualPartitioning) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + param = f32[8,2] parameter(0), sharding={devices=[2,1]0,1} + to_shard = f32[4,2] custom-call(param), custom_call_target="SPMDFullToShardShape", sharding={replicated} + add = f32[4,2] add(to_shard, to_shard), sharding={replicated} + to_full = f32[8,2] custom-call(add), custom_call_target="SPMDShardToFullShape", sharding={devices=[2,1]0,1} + ROOT mul = f32[8,2] multiply(to_full, param), sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + auto to_shard = op::Copy(op::Parameter(0)); + EXPECT_THAT(root, AllOf(op::Shape("f32[4,2]"), + op::Multiply(op::Copy(op::Add(to_shard, to_shard)), + op::Parameter(0)))); +} + +} // namespace +} // namespace spmd +} // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc new file mode 100644 index 00000000000..df7597628af --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc @@ -0,0 +1,874 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h" + +#include + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace spmd { + +bool HasReplicatedSharding(const HloSharding& sharding) { + if (sharding.IsTuple()) { + return absl::c_any_of(sharding.tuple_elements(), HasReplicatedSharding); + } + return sharding.IsReplicated(); +} + +HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b) { + if (shape.IsTuple()) { + std::vector elements; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + elements.push_back( + CreateZero(ShapeUtil::GetTupleElementShape(shape, i), b)); + } + return b->AddInstruction(HloInstruction::CreateTuple(elements)); + } + + if (shape.IsToken()) { + return b->AddInstruction(HloInstruction::CreateToken()); + } + auto zero = b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type()))); + return b->AddInstruction(HloInstruction::CreateBroadcast(shape, zero, {})); +} + +HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module) { + HloComputation::Builder sum_b("add"); + auto x = sum_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(type, {}), "x")); + auto y = sum_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(type, {}), "y")); + if (type == PRED) { + sum_b.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(type, {}), HloOpcode::kOr, x, y)); + } else { + sum_b.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(type, {}), HloOpcode::kAdd, x, y)); + } + HloComputation* reduction = module->AddEmbeddedComputation(sum_b.Build()); + return reduction; +} + +bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding) { + if (sharding.IsTuple()) { + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + if (!EvenlyPartitions(ShapeUtil::GetTupleElementShape(shape, i), + sharding.GetSubSharding(shape, {i}))) { + return false; + } + } + } + + if (sharding.IsTileMaximal()) { + return sharding.IsReplicated(); + } + for (int64 i = 0; i < shape.dimensions_size(); ++i) { + if (shape.dimensions(i) % sharding.tile_assignment().dim(i) != 0) { + return false; + } + } + return true; +} + +Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding) { + if (sharding.IsTuple()) { + std::vector subshapes; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + subshapes.push_back( + MakePartitionedShape(ShapeUtil::GetTupleElementShape(shape, i), + sharding.GetSubSharding(shape, {i}))); + } + return ShapeUtil::MakeTupleShape(subshapes); + } + return sharding.TileShape(shape); +} + +int64 ShapeSizeInBytes(const Shape& shape) { + return 
ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()) * + ShapeUtil::ElementsIn(shape); +} + +Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape, + const HloSharding& sharding, + int64 partition_id) { + if (sharding.IsTuple()) { + std::vector subshapes; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + subshapes.push_back(MakeNonPaddedShapeForGivenPartition( + ShapeUtil::GetTupleElementShape(shape, i), + sharding.GetSubSharding(shape, {i}), partition_id)); + } + return ShapeUtil::MakeTupleShape(subshapes); + } + + auto partition_shape = shape; + std::vector tile_offset = + sharding.TileOffsetForDevice(shape, partition_id); + std::vector tile_limit = + sharding.TileLimitForDevice(shape, partition_id); + for (int64 i = 0; i < tile_offset.size(); ++i) { + if (sharding.UsesDevice(partition_id)) { + partition_shape.set_dimensions(i, tile_limit[i] - tile_offset[i]); + } else { + partition_shape.set_dimensions(i, 0); + } + } + return partition_shape; +} + +std::vector MakePartitionOffsets(const Shape& shape, + const HloSharding& sharding, + HloInstruction* partition_id, + SpmdBuilder* b) { + CHECK(!shape.IsTuple()); + + Array2D offset_array( + {sharding.tile_assignment().num_elements(), shape.rank()}); + offset_array.Each([&](int64 i, int64 j, int32* value) { + *value = sharding.TileOffsetForDevice(shape, i)[j]; + }); + auto offset_table = b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2FromArray2D(offset_array))); + std::vector offsets; + for (int64 i = 0; i < shape.rank(); ++i) { + if (sharding.tile_assignment().dim(i) == 1) { + offsets.push_back(b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(S32)))); + } else { + auto index = b->AddInstruction(HloInstruction::CreateDynamicSlice( + ShapeUtil::MakeShape(S32, {1, 1}), offset_table, + {partition_id, b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(i)))}, + {1, 1})); + offsets.push_back(b->AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), index))); + } + } + return offsets; +} + +std::vector MakeTiledPartitionOrdinals( + const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b) { + CHECK(!sharding.IsTileMaximal()); + auto table_shape = + ShapeUtil::MakeShape(S32, sharding.tile_assignment().dimensions()); + return MakePartitionOffsets(table_shape, sharding, partition_id, b); +} + +HloInstruction* PadToShape(HloInstruction* hlo, const Shape& padded_shape, + SpmdBuilder* b, HloComputation* computation) { + CHECK(b == nullptr || computation == nullptr); + if (ShapeUtil::Compatible(hlo->shape(), padded_shape)) { + return hlo; + } + PaddingConfig padding_config; + for (int64 i = 0; i < padded_shape.rank(); ++i) { + auto padding_config_dim = padding_config.add_dimensions(); + padding_config_dim->set_edge_padding_low(0); + padding_config_dim->set_interior_padding(0); + padding_config_dim->set_edge_padding_high(padded_shape.dimensions(i) - + hlo->shape().dimensions(i)); + } + auto add_hlo = [&](std::unique_ptr to_add) { + if (b == nullptr) { + return computation->AddInstruction(std::move(to_add)); + } + return b->AddInstruction(std::move(to_add)); + }; + auto zero = add_hlo(HloInstruction::CreateConstant( + LiteralUtil::Zero(hlo->shape().element_type()))); + return add_hlo( + HloInstruction::CreatePad(padded_shape, hlo, zero, padding_config)); +} + +Shape GetPaddedShapeForUnevenPartitioning(const Shape& base_shape, + const HloSharding& sharding) { + if (sharding.IsTileMaximal()) { + return base_shape; + } + 
if (EvenlyPartitions(base_shape, sharding)) { + return base_shape; + } + auto shard_shape = MakePartitionedShape(base_shape, sharding); + Shape padded_base_shape = base_shape; + for (int64 i = 0; i < padded_base_shape.rank(); ++i) { + padded_base_shape.set_dimensions( + i, shard_shape.dimensions(i) * sharding.tile_assignment().dim(i)); + } + return padded_base_shape; +} + +HloInstruction* PadBaseShapeBeforeUnevenTiledSharding( + HloInstruction* hlo, const HloSharding& sharding, SpmdBuilder* b) { + auto padded_base_shape = + GetPaddedShapeForUnevenPartitioning(hlo->shape(), sharding); + if (ShapeUtil::Compatible(padded_base_shape, hlo->shape())) { + return hlo; + } + return PadToShape(hlo, padded_base_shape, b); +} + +absl::optional UniqueTiledDim(const HloSharding& sharding) { + if (sharding.IsTileMaximal()) { + return absl::nullopt; + } + int64 dim = -1; + for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) { + if (sharding.tile_assignment().dim(i) > 1) { + if (dim != -1) { + return absl::nullopt; + } + dim = i; + } + } + CHECK_NE(dim, -1); + return dim; +} + +MultiplyAddDivideOffsetCalculation::MultiplyAddDivideOffsetCalculation( + int64 multiplier, int64 offset, int64 divisor) + : multiplier_(multiplier), offset_(offset), divisor_(divisor) { + CHECK_GT(divisor_, 0); + Simplify(); +} + +OffsetCalculation MultiplyAddDivideOffsetCalculation::operator-( + const MultiplyAddDivideOffsetCalculation& other) const { + if (divisor_ == 1 && other.divisor_ == 1) { + return OffsetCalculation(MultiplyAddDivideOffsetCalculation( + multiplier_ - other.multiplier_, offset_ - other.offset_, 1)); + } + return OffsetCalculation(HloOpcode::kSubtract, *this, other); +} + +void MultiplyAddDivideOffsetCalculation::Simplify() { + // We could simplify the calculation when multiplier is a multiple of + // divisor_. However, when offset_ is not a multiple of divisor_, we must + // make sure that offset_ and multiplier_ are both non-negative or both + // non-positive. E.g., (3 * i - 1) / 3 is not equivalent to i or i - 1. 
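  // For illustration (worked example using the constructor above):
  //   MultiplyAddDivideOffsetCalculation calc(/*multiplier=*/3, /*offset=*/-1,
  //                                           /*divisor=*/3);
  //   calc.Calculate(0) == 0   // equals i
  //   calc.Calculate(2) == 1   // equals i - 1
  // With truncating integer division no single simplified form matches every
  // shard ordinal, so the fold below is skipped when the sign condition fails.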
+ if (divisor_ != 1 && multiplier_ % divisor_ == 0 && + (offset_ % divisor_ == 0 || offset_ * multiplier_ > 0)) { + multiplier_ /= divisor_; + offset_ /= divisor_; + divisor_ = 1; + } +} + +int64 MultiplyAddDivideOffsetCalculation::Calculate(int64 shard_ordinal) const { + return (shard_ordinal * multiplier_ + offset_) / divisor_; +} + +HloInstruction* MultiplyAddDivideOffsetCalculation::Calculate( + HloInstruction* shard_ordinal, SpmdBuilder* b) const { + auto scalar_shape = ShapeUtil::MakeShape(S32, {}); + if (multiplier_ == 0) { + return b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(offset_ / divisor_))); + } + HloInstruction* result = shard_ordinal; + if (multiplier_ != 1) { + result = b->AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kMultiply, shard_ordinal, + b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(multiplier_))))); + } + if (offset_ != 0) { + auto offset = b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(offset_))); + result = b->AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, result, offset)); + } + if (divisor_ != 1) { + auto divisor = b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(divisor_))); + result = b->AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kDivide, result, divisor)); + } + return result; +} + +int64 MultiplyAddDivideOffsetCalculation::MaxInRange( + int64 start_ordinal, int64 limit_ordinal) const { + int64 max = Calculate(start_ordinal); + for (int64 i = start_ordinal + 1; i < limit_ordinal; ++i) { + max = std::max(max, Calculate(i)); + } + return max; +} + +OffsetCalculation& OffsetCalculation::operator=( + const OffsetCalculation& other) { + opcode_ = other.opcode_; + copy_from_ = other.copy_from_; + if (opcode_ != HloOpcode::kCopy) { + lhs_ = absl::make_unique(*other.lhs_); + rhs_ = absl::make_unique(*other.rhs_); + } + return *this; +} + +bool OffsetCalculation::IsConstant() const { + if (opcode_ == HloOpcode::kCopy) { + return copy_from_.IsConstant(); + } + if (opcode_ == HloOpcode::kSubtract && *lhs_ == *rhs_) { + return true; + } + return lhs_->IsConstant() && rhs_->IsConstant(); +} + +OffsetCalculation OffsetCalculation::operator-( + const OffsetCalculation& other) const { + if (opcode_ == HloOpcode::kCopy && other.opcode_ == HloOpcode::kCopy) { + return copy_from_ - other.copy_from_; + } + return OffsetCalculation(HloOpcode::kSubtract, *this, other); +} + +bool OffsetCalculation::operator==(const OffsetCalculation& other) const { + if (opcode_ != other.opcode_) { + return false; + } + if (opcode_ == HloOpcode::kCopy) { + return copy_from_ == other.copy_from_; + } + return *lhs_ == *other.lhs_ && *rhs_ == *other.rhs_; +} + +int64 OffsetCalculation::Calculate(int64 shard_ordinal) const { + switch (opcode_) { + case HloOpcode::kCopy: + return copy_from_.Calculate(shard_ordinal); + case HloOpcode::kSubtract: + return lhs_->Calculate(shard_ordinal) - rhs_->Calculate(shard_ordinal); + case HloOpcode::kMultiply: + return lhs_->Calculate(shard_ordinal) * rhs_->Calculate(shard_ordinal); + default: + LOG(FATAL) << "Should not happen"; + } +} + +HloInstruction* OffsetCalculation::Calculate(HloInstruction* shard_ordinal, + SpmdBuilder* b) const { + if (opcode_ == HloOpcode::kCopy) { + return copy_from_.Calculate(shard_ordinal, b); + } + auto lhs = lhs_->Calculate(shard_ordinal, b); + auto rhs = rhs_->Calculate(shard_ordinal, b); + return b->AddInstruction( + 
HloInstruction::CreateBinary(lhs->shape(), opcode_, lhs, rhs)); +} + +int64 OffsetCalculation::MaxInRange(int64 start_ordinal, + int64 limit_ordinal) const { + if (IsConstant()) { + return Calculate(start_ordinal); + } + if (opcode_ == HloOpcode::kCopy) { + return std::max(Calculate(start_ordinal), Calculate(limit_ordinal - 1)); + } + int64 max = Calculate(start_ordinal); + for (int64 i = start_ordinal + 1; i < limit_ordinal; ++i) { + max = std::max(max, Calculate(i)); + } + return max; +} + +absl::optional ExchangeHalo( + HloInstruction* hlo, const OffsetCalculation& left_halo_size_function, + const OffsetCalculation& right_halo_size_function, int64 dim, + const HloSharding& target, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdBuilder* b) { + int64 input_shard_size = hlo->shape().dimensions(dim); + int64 shard_count = target.tile_assignment().dim(dim); + + std::vector concat_pieces; + + int64 max_left_halo_size = left_halo_size_function.MaxInRange(1, shard_count); + for (int64 i = CeilOfRatio(max_left_halo_size, input_shard_size) - 1; i >= 0; + --i) { + std::vector> source_target_pairs; + target.tile_assignment().Each( + [&](absl::Span indices, int64 device) { + if (indices[dim] > i) { + std::vector source_indices(indices.begin(), indices.end()); + source_indices[dim] -= i + 1; + source_target_pairs.emplace_back( + target.tile_assignment()(source_indices), device); + } + }); + int64 halo_size = + std::min(max_left_halo_size - input_shard_size * i, input_shard_size); + auto halo_shape = hlo->shape(); + auto source_halo_slice = hlo; + if (halo_size != hlo->shape().dimensions(dim)) { + halo_shape.set_dimensions(dim, halo_size); + std::vector halo_start_indices(halo_shape.rank(), 0); + halo_start_indices[dim] = hlo->shape().dimensions(dim) - halo_size; + std::vector halo_slice_strides(halo_shape.rank(), 1); + source_halo_slice = b->AddInstruction(HloInstruction::CreateSlice( + halo_shape, hlo, halo_start_indices, hlo->shape().dimensions(), + halo_slice_strides)); + } + auto left_halo = + collective_ops_creator.create_cross_partition_collective_permute( + b, source_halo_slice, source_target_pairs, (*next_channel_id)++); + concat_pieces.push_back(left_halo); + } + + concat_pieces.push_back(hlo); + + // Right halo. 
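  // For illustration (assuming a tile assignment of {0, 1}): with two shards
  // of size 4 along `dim` and a right-halo size of 1, the loop below runs once
  // (i = 0), builds the source-target pair {1 -> 0}, and the collective-permute
  // sends the first element of partition 1's shard to partition 0, where it is
  // concatenated after the local data.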
+ int64 max_right_halo_size = + right_halo_size_function.MaxInRange(0, shard_count - 1); + for (int64 i = 0; i < CeilOfRatio(max_right_halo_size, input_shard_size); + ++i) { + std::vector> source_target_pairs; + target.tile_assignment().Each( + [&](absl::Span indices, int64 device) { + if (indices[dim] > i) { + std::vector target_indices(indices.begin(), indices.end()); + target_indices[dim] -= i + 1; + source_target_pairs.emplace_back( + device, target.tile_assignment()(target_indices)); + } + }); + int64 halo_size = + std::min(max_right_halo_size - input_shard_size * i, input_shard_size); + auto halo_shape = hlo->shape(); + HloInstruction* source_halo_slice = hlo; + if (halo_size != halo_shape.dimensions(dim)) { + halo_shape.set_dimensions(dim, halo_size); + std::vector halo_start_indices(halo_shape.rank(), 0); + std::vector halo_slice_strides(halo_shape.rank(), 1); + source_halo_slice = b->AddInstruction(HloInstruction::CreateSlice( + halo_shape, hlo, halo_start_indices, halo_shape.dimensions(), + halo_slice_strides)); + } + auto right_halo = + collective_ops_creator.create_cross_partition_collective_permute( + b, source_halo_slice, source_target_pairs, (*next_channel_id)++); + concat_pieces.push_back(right_halo); + } + + auto concat = hlo; + // Concat with halos/padding. + if (concat_pieces.size() > 1) { + auto concat_shape = hlo->shape(); + int64 concat_dim_size = 0; + for (auto piece : concat_pieces) { + concat_dim_size += piece->shape().dimensions(dim); + } + concat_shape.set_dimensions(dim, concat_dim_size); + concat = b->AddInstruction( + HloInstruction::CreateConcatenate(concat_shape, concat_pieces, dim)); + } + + return concat; +} + +absl::optional ExchangeHalo( + HloInstruction* hlo, + std::vector left_halo_size_functions, + std::vector right_halo_size_functions, + const HloSharding& target, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdBuilder* b) { + CHECK(left_halo_size_functions.size() == hlo->shape().rank()); + CHECK(right_halo_size_functions.size() == hlo->shape().rank()); + + HloInstruction* visiting_hlo = hlo; + for (int dim = 0; dim < hlo->shape().rank(); ++dim) { + auto concat = ExchangeHalo(visiting_hlo, left_halo_size_functions[dim], + right_halo_size_functions[dim], dim, target, + collective_ops_creator, next_channel_id, b); + if (!concat) { + return absl::nullopt; + } + visiting_hlo = *concat; + } + return visiting_hlo; +} + +absl::optional ExchangeHaloAndGetValidData( + HloInstruction* hlo, const Shape& base_shape, + const OffsetCalculation& left_halo_size_function, + const OffsetCalculation& right_halo_size_function, + int64 explicit_left_padding_on_full_shape, int64 padded_full_shape_size, + int64 shard_size_with_halo, int64 dim, const HloSharding& target, + HloInstruction* offset_on_padded_shape, HloInstruction* pad_value, + HloInstruction* partition_ordinal, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdBuilder* b, bool mask_invalid_region) { + auto halo_exchange_result = + ExchangeHalo(hlo, left_halo_size_function, right_halo_size_function, dim, + target, collective_ops_creator, next_channel_id, b); + if (!halo_exchange_result) { + return absl::nullopt; + } + auto concat = *halo_exchange_result; + int64 shard_count = target.tile_assignment().dim(dim); + int64 max_left_halo_size = left_halo_size_function.MaxInRange(1, shard_count); + + // Now we determine if we need extra padding after the concat. + // + // The max of halo size or the first shard's explicit left padding. 
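  // For illustration: if the largest left halo across shards is 2 and the
  // explicit left padding on the full shape is 3, then
  // max_left_halo_or_padding_size below is 3, and a shard whose own left halo
  // is 2 starts its dynamic-slice into the padded concat at offset 3 - 2 = 1.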
+ int64 max_left_halo_or_padding_size = + std::max(std::max(int64{0}, max_left_halo_size), + explicit_left_padding_on_full_shape); + // The calculation that returns the dynamic slice index for a shard on the + // padded concat, which is the difference between + // max_left_halo_or_padding_size and its left halo size. + auto start_offset_on_padded_concat_calculation = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + 0, max_left_halo_or_padding_size, 1)) - + left_halo_size_function; + + // See if we need to pad the concat before dynamic slice. + int64 extra_left_padding = + std::max(int64{0}, max_left_halo_or_padding_size - + std::max(int64{0}, max_left_halo_size)); + int64 extra_right_padding = + start_offset_on_padded_concat_calculation.MaxInRange(0, shard_count) + + shard_size_with_halo - concat->shape().dimensions(dim) - + extra_left_padding; + extra_right_padding = std::max(int64{0}, extra_right_padding); + if (extra_left_padding > 0 || extra_right_padding > 0) { + PaddingConfig padding_config; + auto padded_concat_shape = concat->shape(); + for (int64 i = 0; i < base_shape.rank(); ++i) { + auto padding_config_dim = padding_config.add_dimensions(); + padding_config_dim->set_interior_padding(0); + padding_config_dim->set_edge_padding_low(0); + padding_config_dim->set_edge_padding_high(0); + if (i != dim) { + continue; + } + padding_config_dim->set_edge_padding_low(extra_left_padding); + padding_config_dim->set_edge_padding_high(extra_right_padding); + padded_concat_shape.set_dimensions(dim, concat->shape().dimensions(dim) + + extra_left_padding + + extra_right_padding); + } + concat = b->AddInstruction(HloInstruction::CreatePad( + padded_concat_shape, concat, pad_value, padding_config)); + } + + auto valid_slice = concat; + if (shard_size_with_halo != concat->shape().dimensions(dim)) { + // Concat is bigger than the shard shape, so we need a dynamic slice. + CHECK_LT(shard_size_with_halo, concat->shape().dimensions(dim)); + auto slice_shape = concat->shape(); + slice_shape.set_dimensions(dim, shard_size_with_halo); + + if (left_halo_size_function.IsConstant() && + left_halo_size_function.Calculate(0) == + explicit_left_padding_on_full_shape) { + std::vector start_indices(slice_shape.rank(), 0); + std::vector strides(slice_shape.rank(), 1); + valid_slice = b->AddInstruction( + HloInstruction::CreateSlice(slice_shape, concat, start_indices, + slice_shape.dimensions(), strides)); + } else { + auto zero = b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(S32))); + std::vector slice_offsets(base_shape.rank(), zero); + slice_offsets[dim] = start_offset_on_padded_concat_calculation.Calculate( + partition_ordinal, b); + valid_slice = b->AddInstruction(HloInstruction::CreateDynamicSlice( + slice_shape, concat, slice_offsets, slice_shape.dimensions())); + } + } + + if (!mask_invalid_region) { + return valid_slice; + } + + int64 total_right_padding = padded_full_shape_size - + base_shape.dimensions(dim) - + explicit_left_padding_on_full_shape; + // Mask off garbage data due to uneven partition or low/high padding. 
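  // For illustration: a base dimension of size 5 tiled across two shards is
  // padded to 6, so (with no explicit left padding) the last element of
  // partition 1 sits at index 5 of the padded full shape; 5 is not less than
  // the valid limit of 5, so the compare below masks that element and
  // pad_value is selected instead.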
+ if (explicit_left_padding_on_full_shape > 0 || total_right_padding > 0) { + auto index_shape = ShapeUtil::ChangeElementType(valid_slice->shape(), S32); + auto iota = b->AddInstruction(HloInstruction::CreateIota(index_shape, dim)); + auto broadcast_start_index_in_padded_shape = + b->AddInstruction(HloInstruction::CreateBroadcast( + index_shape, offset_on_padded_shape, {})); + auto index_in_padded_shape = b->AddInstruction( + HloInstruction::CreateBinary(index_shape, HloOpcode::kAdd, iota, + broadcast_start_index_in_padded_shape)); + auto mask_shape = ShapeUtil::ChangeElementType(index_shape, PRED); + std::vector predicates; + if (explicit_left_padding_on_full_shape > 0) { + auto valid_index_start = + b->AddInstruction(HloInstruction::CreateBroadcast( + index_shape, + b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0( + explicit_left_padding_on_full_shape))), + {})); + predicates.push_back(b->AddInstruction(HloInstruction::CreateCompare( + mask_shape, index_in_padded_shape, valid_index_start, + ComparisonDirection::kGe))); + } + if (total_right_padding > 0) { + auto valid_index_limit = + b->AddInstruction(HloInstruction::CreateBroadcast( + index_shape, + b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0( + base_shape.dimensions(dim) + + explicit_left_padding_on_full_shape))), + {})); + predicates.push_back(b->AddInstruction(HloInstruction::CreateCompare( + mask_shape, index_in_padded_shape, valid_index_limit, + ComparisonDirection::kLt))); + } + CHECK(!predicates.empty()); + auto is_valid = + predicates.size() == 2 + ? b->AddInstruction(HloInstruction::CreateBinary( + mask_shape, HloOpcode::kAnd, predicates[0], predicates[1])) + : predicates[0]; + auto masking_value = b->AddInstruction( + HloInstruction::CreateBroadcast(valid_slice->shape(), pad_value, {})); + valid_slice = b->AddInstruction( + HloInstruction::CreateTernary(valid_slice->shape(), HloOpcode::kSelect, + is_valid, valid_slice, masking_value)); + } + return valid_slice; +} + +HloInstruction* HaloExchangeToPadOnLeft(PartitionedHlo& original, + absl::Span dims) { + if (original.sharding().IsTileMaximal()) { + return original.hlo(); + } + // Create a window config to halo exchange for unevenly partitioned reverse + // dimensions. 
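  // For illustration: a base dimension of size 3 tiled across two shards is
  // rounded up to 4, so padding_low below becomes 4 - 3 = 1; the one element
  // of uneven padding that tiled sharding normally keeps on the right is then
  // presented on the left instead, which is what kReverse expects.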
+ Window window; + for (int64 i = 0; i < original.base_shape().rank(); ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_size(1); + dim->set_stride(1); + dim->set_window_dilation(1); + dim->set_window_reversal(false); + int64 low_padding = 0; + if (absl::c_linear_search(dims, i)) { + low_padding = + RoundUpToNearest(original.base_shape().dimensions(i), + original.sharding().tile_assignment().dim(i)) - + original.base_shape().dimensions(i); + } + dim->set_padding_low(low_padding); + dim->set_padding_high(0); + dim->set_base_dilation(1); + } + + auto reshard_window = original.ReshardAsWindowedInput( + window, original.sharding(), + CreateZero(ShapeUtil::MakeShape(original.base_shape().element_type(), {}), + original.state().b), + /*mask_invalid_region=*/false); + if (!reshard_window.has_value()) { + return nullptr; + } + CHECK(!reshard_window->dynamic_slice_index_on_output.has_value()); + return reshard_window->sharded_input; +} + +bool IsNanSafeGt(HloComputation* comp) { + namespace m = match; + auto match_bitcast_f32 = [](int64 parameter_number) { + auto param = m::Parameter(parameter_number) + .WithShape(m::Shape().WithElementType(F32)); + auto param_s32 = + m::BitcastConvert(param).WithShape(m::Shape().WithElementType(S32)); + auto param_u32 = + m::BitcastConvert(param).WithShape(m::Shape().WithElementType(U32)); + return m::Select( + m::Lt(param_s32, m::ConstantScalar(0)), + m::BitcastConvert( + m::Subtract(m::ConstantScalar(std::numeric_limits::max()), + param_u32)) + .WithShape(m::Shape().WithElementType(S32)), + param_s32); + }; + auto match_bitcast_bf16 = [](int64 parameter_number) { + auto param = m::Convert(m::Parameter(parameter_number) + .WithShape(m::Shape().WithElementType(BF16))) + .WithShape(m::Shape().WithElementType(F32)); + auto param_s32 = + m::BitcastConvert(param).WithShape(m::Shape().WithElementType(S32)); + auto param_u32 = + m::BitcastConvert(param).WithShape(m::Shape().WithElementType(U32)); + return m::Select( + m::Lt(param_s32, m::ConstantScalar(0)), + m::BitcastConvert( + m::Subtract(m::ConstantScalar(std::numeric_limits::max()), + param_u32)) + .WithShape(m::Shape().WithElementType(S32)), + param_s32); + }; + // If root instruction is kSelect and compares indices if values are equal. 
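  // (Assumed context: a sort comparator with an index tie-break is typically
  //  of the form select(values-equal, index-compare, values-greater), so only
  //  operand 2 of the select -- the value comparison -- is matched against the
  //  NaN-safe GT pattern below.)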
+ if (comp->root_instruction()->opcode() == HloOpcode::kSelect) { + return Match(comp->root_instruction()->operand(2), + m::Gt(match_bitcast_f32(0), match_bitcast_f32(1))) || + Match(comp->root_instruction()->operand(2), + m::Gt(match_bitcast_bf16(0), match_bitcast_bf16(1))); + } + return Match(comp->root_instruction(), + m::Gt(match_bitcast_f32(0), match_bitcast_f32(1))) || + Match(comp->root_instruction(), + m::Gt(match_bitcast_bf16(0), match_bitcast_bf16(1))); +} + +absl::optional GetKValueInTopKWhenPartitionSortDim(HloInstruction* hlo) { + HloSortInstruction* sort = DynCast(hlo); + if (sort == nullptr || sort->operand_count() != 2) { + return absl::nullopt; + } + if (!IsNanSafeGt(sort->to_apply())) { + return absl::nullopt; + } + HloInstruction* data = sort->mutable_operand(0); + HloIotaInstruction* iota = + DynCast(sort->mutable_operand(1)); + const PrimitiveType element_type = data->shape().element_type(); + if (iota == nullptr || iota->shape().element_type() != S32 || + iota->opcode() != HloOpcode::kIota || + iota->iota_dimension() != sort->sort_dimension()) { + return absl::nullopt; + } + + const int64 sort_dim = sort->sort_dimension(); + + if (element_type != F32 && element_type != BF16 && element_type != S32 && + element_type != U32) { + return absl::nullopt; + } + + bool supported = true; + absl::optional k; + for (HloInstruction* gte : sort->users()) { + if (gte->opcode() != HloOpcode::kGetTupleElement) { + supported = false; + break; + } + + const HloInstruction* slice = gte->users()[0]; + if (slice->opcode() != HloOpcode::kSlice) { + // Non-slice user means we are not doing a TopK + supported = false; + break; + } + if (absl::c_any_of(slice->slice_starts(), [](int x) { return x != 0; }) || + absl::c_any_of(slice->slice_strides(), [](int x) { return x != 1; })) { + // Strided slice or slicing at the beginning isn't supported. + supported = false; + break; + } + for (int64 dim = 0; dim < data->shape().dimensions_size(); dim++) { + if (dim == sort_dim) { + continue; + } + if (slice->slice_limits(dim) != + slice->operand(0)->shape().dimensions(dim)) { + // Slicing along the other dimension isn't supported. + supported = false; + break; + } + } + if (!k.has_value()) { + k = slice->slice_limits(sort_dim); + } else if (k != slice->slice_limits(sort_dim)) { + // Different k for the different operands isn't supported. + supported = false; + break; + } + } + if (k == absl::nullopt || !supported) { + return absl::nullopt; + } + + // Only support when sort dim is sharded. + if (!data->has_sharding()) { + return absl::nullopt; + } + const HloSharding& sharding = sort->operand(0)->sharding(); + + if (sharding.IsTileMaximal()) { + return absl::nullopt; + } + + // Check if partitioned at sort dimension. + for (int64 dim : sort->dimensions()) { + if (sharding.tile_assignment().dim(dim) > 1) { + if (dim != sort_dim) { + return absl::nullopt; + } + } + } + + // Checks if partition size is smaller than k. + const int64 shard_count = sharding.tile_assignment().dim(sort_dim); + + if (shard_count <= 1) { + return absl::nullopt; + } + + const int64 input_size = hlo->operand(0)->shape().dimensions(sort_dim); + const int64 per_partition_size = CeilOfRatio(input_size, shard_count); + + if (k.value() >= per_partition_size) { + return absl::nullopt; + } + + return k; +} + +// Slice first k elements from sort_dim. 
+HloInstruction* SliceFirstK(HloInstruction* hlo, SpmdBuilder* builder, + int64 slice_dim, int64 k) { + const Shape& hlo_shape = hlo->shape(); + auto hlo_dims = hlo_shape.dimensions(); + std::vector start_indices(hlo_shape.dimensions_size(), 0); + std::vector limit_indices(hlo_dims.begin(), hlo_dims.end()); + std::vector strides(hlo_shape.dimensions_size(), 1); + limit_indices[slice_dim] = k; + auto output_shape = hlo_shape; + output_shape.set_dimensions(slice_dim, k); + return builder->AddInstruction(HloInstruction::CreateSlice( + output_shape, hlo, start_indices, limit_indices, strides)); +} + +} // namespace spmd +} // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h new file mode 100644 index 00000000000..5f245667970 --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h @@ -0,0 +1,268 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h" + +namespace xla { +namespace spmd { + +// Returns true if the given sharding contains any replicated sharding. +bool HasReplicatedSharding(const HloSharding& sharding); + +// Creates zero value instructions of the given shape. +HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b); + +template +HloInstruction* CreateR0WithType(PrimitiveType type, NativeT value, + SpmdBuilder* b) { + auto literal = LiteralUtil::CreateR0(value) + .ConvertToShape(ShapeUtil::MakeShape(type, {})) + .ValueOrDie(); + return b->AddInstruction(HloInstruction::CreateConstant(std::move(literal))); +} + +inline HloInstruction* CreateFirstWithType(PrimitiveType type, SpmdBuilder* b) { + if (type == F32) { + auto float_pad_value = std::numeric_limits::quiet_NaN(); + return CreateR0WithType(type, -float_pad_value, b); + } + auto literal = LiteralUtil::MinValue(type); + return b->AddInstruction(HloInstruction::CreateConstant(std::move(literal))); +} + +inline HloInstruction* CreateLastWithType(PrimitiveType type, SpmdBuilder* b) { + if (type == F32) { + auto float_pad_value = std::numeric_limits::quiet_NaN(); + return CreateR0WithType(type, float_pad_value, b); + } + auto literal = LiteralUtil::MaxValue(type); + return b->AddInstruction(HloInstruction::CreateConstant(std::move(literal))); +} + +// Create a binary add computation of the given type and add to the module. 
+HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module); + +// Returns true if the shape can be evenly partitioned for the given sharding. +// All tile sharded dimensions should be evenly divisible and there should be no +// single-device sharding. Replicate sharding is considered even partition. +bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding); + +// Returns the shard shape of the given shape when it is partitioned for the +// target sharding. +Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding); + +// Similar to ShapeUtil::ByteSizeOf(), but does not check it has dense layout +// since this can be before layout assignment. +int64 ShapeSizeInBytes(const Shape& shape); + +// Returns the shard shape for a partition without padding due to uneven +// sharding. +Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape, + const HloSharding& sharding, + int64 partition_id); + +// Generates the HLO instructions that represent the dimension offsets on any +// device. The size of the returned vector is the rank of the given shape. +std::vector MakePartitionOffsets(const Shape& shape, + const HloSharding& sharding, + HloInstruction* partition_id, + SpmdBuilder* b); + +// Returns the offsets of the partition in the tile assignment. +std::vector MakeTiledPartitionOrdinals( + const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b); + +// Pads hlo to the desired shape using high padding. Either a builder or a +// computation needs to be supplied, but not both. +HloInstruction* PadToShape(HloInstruction* hlo, const Shape& padded_shape, + SpmdBuilder* b, + HloComputation* computation = nullptr); + +// Returns the padded shape when combining all partitions. +Shape GetPaddedShapeForUnevenPartitioning(const Shape& base_shape, + const HloSharding& sharding); + +// Pads the HLO (with base shape) for uneven tiled partition to make it evenly +// partitionable. +HloInstruction* PadBaseShapeBeforeUnevenTiledSharding( + HloInstruction* hlo, const HloSharding& sharding, SpmdBuilder* b); + +// Returns the index of the unique tile dimension. Returns absl::nullopt if the +// given sharding is not tiled or tiled along multiple dimensions. +absl::optional UniqueTiledDim(const HloSharding& sharding); + +// Utilities for symbolic offset calculation and halo exchange. +class OffsetCalculation; + +// Represents a calculation over integers: +// (shard_ordinal * multiplier + offset) / divisor +class MultiplyAddDivideOffsetCalculation { + public: + MultiplyAddDivideOffsetCalculation() + : multiplier_(0), offset_(0), divisor_(1) {} + MultiplyAddDivideOffsetCalculation(int64 multiplier, int64 offset, + int64 divisor); + + OffsetCalculation operator-( + const MultiplyAddDivideOffsetCalculation& other) const; + + bool operator==(const MultiplyAddDivideOffsetCalculation& other) const { + return multiplier_ == other.multiplier_ && offset_ == other.offset_ && + divisor_ == other.divisor_; + } + + bool IsConstant() const { return multiplier_ == 0; } + void Simplify(); + int64 Calculate(int64 shard_ordinal) const; + HloInstruction* Calculate(HloInstruction* shard_ordinal, + SpmdBuilder* b) const; + + // Returns the maximum result for shard ordinals in the range + // [start_ordinal, limit_ordinal). + int64 MaxInRange(int64 start_ordinal, int64 limit_ordinal) const; + + private: + int64 multiplier_; + int64 offset_; + int64 divisor_; +}; + +// Represents a calculation over integers based on results of other calculations +// defined by an opcode. 
If the opcode is kCopy, it simply wraps an +// MultiplyAddDivideOffsetCalculation. +class OffsetCalculation { + public: + OffsetCalculation() : opcode_(HloOpcode::kCopy), copy_from_() {} + explicit OffsetCalculation( + const MultiplyAddDivideOffsetCalculation& copy_from) + : opcode_(HloOpcode::kCopy), copy_from_(copy_from) {} + OffsetCalculation(const OffsetCalculation& copy_from) { *this = copy_from; } + OffsetCalculation(HloOpcode opcode, + const MultiplyAddDivideOffsetCalculation& lhs, + const MultiplyAddDivideOffsetCalculation& rhs) + : opcode_(opcode), + lhs_(absl::make_unique(lhs)), + rhs_(absl::make_unique(rhs)) {} + OffsetCalculation(HloOpcode opcode, const OffsetCalculation& lhs, + const OffsetCalculation& rhs) + : opcode_(opcode), + lhs_(absl::make_unique(lhs)), + rhs_(absl::make_unique(rhs)) {} + + OffsetCalculation& operator=(const OffsetCalculation& other); + + // Returns whether the calculation returns the same value for all shards. This + // is conservative and could return false even if it is actually constant. + bool IsConstant() const; + + OffsetCalculation operator-(const OffsetCalculation& other) const; + bool operator==(const OffsetCalculation& other) const; + int64 Calculate(int64 shard_ordinal) const; + HloInstruction* Calculate(HloInstruction* shard_ordinal, + SpmdBuilder* b) const; + + // Returns the maximum result for shard ordinals in the range + // [start_ordinal, limit_ordinal). + int64 MaxInRange(int64 start_ordinal, int64 limit_ordinal) const; + + private: + HloOpcode opcode_; + std::unique_ptr lhs_; + std::unique_ptr rhs_; + MultiplyAddDivideOffsetCalculation copy_from_; +}; + +// Performs halo exchange on the given dimension based on the provided +// left/right halo size functions. Returns nullopt if the halo is beyond the +// direct neighbor of the shard. +absl::optional ExchangeHalo( + HloInstruction* hlo, const OffsetCalculation& left_halo_size_function, + const OffsetCalculation& right_halo_size_function, int64 dim, + const HloSharding& target, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdBuilder* b); + +// Exchange halo on all dimensions of the HLO. Returns nullopt if any one of the +// dimensions fails to exchange halo (halo is beyond the neighbor shard). +absl::optional ExchangeHalo( + HloInstruction* hlo, + std::vector left_halo_size_functions, + std::vector right_halo_size_functions, + const HloSharding& target, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdBuilder* b); + +// Exchanges halos and performs pad/dynamic-slice on the concatenated data such +// that the result starts with the first needed element on each shard. It also +// masks off invalid data due to padding. +// Arguments: +// hlo: the HLO op before halo exchange +// explicit_left_padding_on_full_shape: the amount of left padding to be added +// explicitly by this function on the base shape before partitioning. Without +// base dilation, this is usually set to the window's padding_low so that the +// sharded op do not need to add padding_low on the window; however, with base +// dilation, this could only be set to a custom size. +// padded_full_shape_size: the size of the padded full shape on the given +// dimension, which includes explicit_left_padding_on_full_shape and required +// right padding to make the shape evenly shardable. +// shard_size_with_halo: the shard size on the dimension after halo exchange. +// If different shards have different sizes, use the maximum size. 
+// offset_on_padded_shape: the offset HLO (S32) that represents the start of +// each shard on the padded full shape. +// pad_value: the padding value used on the full shape. +absl::optional<HloInstruction*> ExchangeHaloAndGetValidData( + HloInstruction* hlo, const Shape& base_shape, + const OffsetCalculation& left_halo_size_function, + const OffsetCalculation& right_halo_size_function, + int64 explicit_left_padding_on_full_shape, int64 padded_full_shape_size, + int64 shard_size_with_halo, int64 dim, const HloSharding& target, + HloInstruction* offset_on_padded_shape, HloInstruction* pad_value, + HloInstruction* partition_ordinal, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdBuilder* b, bool mask_invalid_region = true); + +// Uses halo exchange to change from right-padding to left-padding for uneven +// tiled sharding on the given dimensions. Tiled sharding always pads uneven +// partitioned data on the right, but we need to swap it to the left for +// kReverse or kConvolution with window reversal. +HloInstruction* HaloExchangeToPadOnLeft(PartitionedHlo& original, + absl::Span<const int64> dims); + +// Checks if the computation is a GT comparison that is safe for NaNs. +bool IsNanSafeGt(HloComputation* computation); + +// Returns k in TopK when the input value is partitioned in the sort dimension. +absl::optional<int64> GetKValueInTopKWhenPartitionSortDim(HloInstruction* hlo); + +// Slices the first k elements at slice dimension. +HloInstruction* SliceFirstK(HloInstruction* hlo, SpmdBuilder* builder, + int64 slice_dim, int64 k); + +} // namespace spmd +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index 2d33184b7d0..1111811d3a3 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -300,7 +300,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( } StatusOr<bool> WhileLoopInvariantCodeMotion::Run(HloModule* module) { - VLOG(2) << "HLO module before WhileLoopConstantSinking:"; + VLOG(2) << "HLO module before WhileLoopInvariantCodeMotion:"; XLA_VLOG_LINES(2, module->ToString()); bool changed = false; @@ -332,10 +332,10 @@ StatusOr<bool> WhileLoopInvariantCodeMotion::Run(HloModule* module) { } if (changed) { - VLOG(2) << "HLO module after WhileLoopConstantSinking:"; + VLOG(2) << "HLO module after WhileLoopInvariantCodeMotion:"; XLA_VLOG_LINES(2, module->ToString()); } else { - VLOG(2) << "HLO module unchanged after WhileLoopConstantSinking"; + VLOG(2) << "HLO module unchanged after WhileLoopInvariantCodeMotion"; } return changed; diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 22ee5a16a30..52cbb8f95ac 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include #include +#include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" #include "absl/strings/ascii.h" #include "absl/strings/numbers.h" @@ -150,6 +151,19 @@ StatusOr<Shape> MakeShapeWithLayoutInternal( return equal; } +/* static */ bool ShapeUtil::EqualStructure(const Shape& lhs, + const Shape& rhs) { + bool equal = true; + ForEachSubshape(lhs, [&](const Shape& /*subshape*/, const ShapeIndex& index) { + equal &= IndexIsValid(rhs, index); + }); + ForEachSubshape(rhs, [&](const Shape& /*subshape*/, const ShapeIndex& index) { + equal &= IndexIsValid(lhs, index); + }); + + return equal; +} + /* static */ int64 ShapeUtil::TrueRank(const Shape& shape) { int64 accum = 0; for (int64 dimension : shape.dimensions()) { @@ -261,6 +275,12 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return ValidateShape(*shape); } +/* static */ Shape ShapeUtil::MakeStaticShape(const Shape& original) { + Shape result = original; + result.clear_dynamic_dimensions(); + return result; +} + /* static */ Shape ShapeUtil::MakeTupleShape(absl::Span<const Shape> shapes) { Shape result; result.set_element_type(TUPLE); @@ -626,8 +646,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( if (shape.element_type() == TUPLE) { return ByteSizeOfTupleIndexTable(shape, pointer_size); } else if (shape.IsArray()) { - int64 byte_size = ByteSizeOfElements(shape); - return byte_size; + return ByteSizeOfElements(shape); } else if (shape.element_type() == TOKEN) { return 0; } else if (shape.element_type() == OPAQUE_TYPE) { @@ -1441,6 +1460,19 @@ ShapeUtil::ReshapeLeavesDimensionsUnmodified( return shape; } +/* static */ bool ShapeUtil::DynamicShapeIsCompatible( + const xla::Shape& dynamic_shape, const xla::Shape& bounded_shape) { + if (dynamic_shape.rank() != bounded_shape.rank()) { + return false; + } + for (int64 i = 0; i < dynamic_shape.rank(); ++i) { + if (dynamic_shape.dimensions(i) > bounded_shape.dimensions(i)) { + return false; + } + } + return true; +} + /* static */ Shape ShapeUtil::FilterDimensions( const std::function<bool(int64)>& p, Shape shape) { CHECK(shape.IsArray()); diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 7e05e17865d..dde56587482 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -298,6 +298,16 @@ class ShapeUtil { // As Equal, but allow one of lhs and rhs to be F16 while the other is F32. static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs); + // Two shapes have the same structure if all subshape indices of lhs are + // present in rhs and vice versa. + // A nested tuple shape of (F32, (S32[2], F32[2, 2])) is structurally equal to + // (S32, (F32[3], S32[2])) as their structures are both (,(,)) + // + // In contrast, (F32, (F32, F32)) is structurally different from + // ((F32, F32), F32) as the former has structure (,(,)) while the latter has + // ((,),) + static bool EqualStructure(const Shape& lhs, const Shape& rhs); + // Returns the number of dimensions for which the dimension is not (trivially) // 1. e.g., f32[2x1x1] has a true rank of 1D, the other dimensions are just // fluff. Note that zero dimensions are included in the true rank, e.g., @@ -339,6 +349,9 @@ class ShapeUtil { // element type changed to type. static Shape ChangeElementType(const Shape& original, PrimitiveType type); + // Returns a shape with the same dimensions but with all dimensions set to static. 
+ static Shape MakeStaticShape(const Shape& original); + // Creates a tuple shape from a slice of element shapes within the tuple. static Shape MakeTupleShape(absl::Span shapes); @@ -643,12 +656,16 @@ class ShapeUtil { static Shape FilterDimensions(const std::function& p, Shape shape); - // Iterates through all the shape indexes, in minor to major order, starting - // from the base indexes, incrementing by the incr steps, up to count - // (index[i] < base[i] + count[i]), and calls the visitor_function with the - // current index. - // The visitor_function visitor function should return true if it wants to - // continue, or false otherwise. + // Returns true if `dynamic_shape` has dimensions that are less-equal to the + // "bounded_shape". + static bool DynamicShapeIsCompatible(const xla::Shape& dynamic_shape, + const xla::Shape& bounded_shape); + + // Iterates through all the shape indexes, in minor to major order, + // starting from the base indexes, incrementing by the incr steps, up to + // count (index[i] < base[i] + count[i]), and calls the visitor_function + // with the current index. The visitor_function visitor function should + // return true if it wants to continue, or false otherwise. // // visitor_function must be a callable of type // StatusOr(absl::Span) or compatible. diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 5b83186ffa4..790497f888e 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -76,6 +76,7 @@ class ClientLibraryTestBase : public ::testing::Test { void SetFastMathDisabled(bool disabled) { auto* opts = execution_options_.mutable_debug_options(); opts->set_xla_cpu_enable_fast_math(!disabled); + opts->set_xla_cpu_enable_fast_min_max(!disabled); opts->set_xla_gpu_enable_fast_min_max(!disabled); } diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 6d64cb0a510..26cb25acbfe 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -416,6 +416,10 @@ XLA_TEST_P(ParametricDotTest, TestF16) { TestImpl(); } #endif XLA_TEST_P(ParametricDotTest, TestF32) { TestImpl(); } XLA_TEST_P(ParametricDotTest, TestF64) { TestImpl(); } +XLA_TEST_P(ParametricDotTest, TestC64) { TestImpl>(); } +#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_COMPLEX128 +XLA_TEST_P(ParametricDotTest, TestC128) { TestImpl>(); } +#endif XLA_TEST_P(ParametricDotTest, TestS32) { TestImpl(); } INSTANTIATE_TEST_CASE_P(DotTests, ParametricDotTest, diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 8eed609a134..7b64be5597b 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -165,6 +165,16 @@ PrecisionConfig HloTestBase::DefaultPrecisionConfig(int operands) { return precision_config; } +void HloTestBase::SetAotFastMathDebugOptions(DebugOptions* options) { + options->set_xla_cpu_enable_fast_math(true); + options->set_xla_gpu_enable_fast_min_max(true); + options->set_xla_cpu_enable_fast_min_max(true); + options->set_xla_cpu_fast_math_honor_nans(false); + options->set_xla_cpu_fast_math_honor_infs(false); + options->set_xla_cpu_fast_math_honor_functions(false); + options->set_xla_cpu_fast_math_honor_division(false); +} + DebugOptions HloTestBase::GetDebugOptionsForTest() { auto debug_options = 
GetDebugOptionsFromFlags(); // TODO(b/38354253): Change tests to use Parameters instead of Constants. diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index d05776a0cb9..85b1876dd3c 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -100,6 +100,10 @@ class HloTestBase : public ::testing::Test { static PrecisionConfig DefaultPrecisionConfig(int operands); + // Sets most fast math options to be enabled to model the fast math flags + // generally used for CPU:AOT compilation. + static void SetAotFastMathDebugOptions(DebugOptions* options); + protected: // This uses the interpreter backend as the reference backend and // automatically finds another supported backend as the test backend. If the diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc index 3407a68f709..40e226f9902 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc @@ -310,8 +310,7 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) { XLA_TEST_F(VecOpsSimpleTest, ClampFloatEdgeCases) { XlaBuilder builder(TestName()); - mutable_debug_options()->set_xla_cpu_enable_fast_math(false); - mutable_debug_options()->set_xla_gpu_enable_fast_min_max(false); + SetFastMathDisabled(true); auto low = ConstantR1<float>(&builder, {NAN, 1, 1}); auto high = ConstantR1<float>(&builder, {3, NAN, 3}); auto x = ConstantR1<float>(&builder, {2, 2, NAN}); diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 5a482305513..d575bbb1f3e 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -863,7 +863,7 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { // Starts = iteration * 2; auto starts = Mul(iteration, ConstantR0<int32>(&builder, 2)); // UpdateSlice. - auto out1 = DynamicUpdateSlice(input, update, starts); + auto out1 = DynamicUpdateSlice(input, update, {starts}); Tuple(&builder, {out0, out1}); body = builder.Build().ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tools/interactive_graphviz.cc b/tensorflow/compiler/xla/tools/interactive_graphviz.cc index 4f8a6b43314..b6c62beff74 100644 --- a/tensorflow/compiler/xla/tools/interactive_graphviz.cc +++ b/tensorflow/compiler/xla/tools/interactive_graphviz.cc @@ -112,8 +112,7 @@ constexpr int64 kDefaultMaxNumNodesInAllPaths = 100; using absl::EqualsIgnoreCase; -// A global control for whether backend configuration display is enabled. -bool show_backend_config = true; +HloRenderOptions hlo_render_options; HloInstruction* FindInstruction(const HloModule& module, string node_name) { if (absl::StartsWith(node_name, "%")) { @@ -160,6 +159,8 @@ void DoHelpCommand() { Renders all nodes in . backend_config [on|off] Controls whether backend operation configuration information is printed. + show_fusion_subcomputations [on|off] + Controls whether fusion subcomputations are shown. list [name|op_name|op_type] Lists all instructions whose name, metadata op_name, or metadata op_type contains as a substring. @@ -182,15 +183,32 @@ void DoHelpCommand() { // Turn metadata-printing on or off. 
void DoBackendConfigCommand(const std::vector& tokens) { if (tokens.size() == 2 && tokens[1] == "on") { - show_backend_config = true; + hlo_render_options.show_backend_config = true; } else if (tokens.size() == 2 && tokens[1] == "off") { - show_backend_config = false; + hlo_render_options.show_backend_config = false; } else if (tokens.size() != 1) { std::cerr << "(Illegal backend_config value. Use either 'on' or 'off'.)" << std::endl; } std::cout << "Backend configuration display " - << (show_backend_config ? "ON" : "OFF") << std::endl; + << (hlo_render_options.show_backend_config ? "ON" : "OFF") + << std::endl; +} + +// Turn fusion computation display on or off. +void DoShowFusionSubcomputationsCommand(const std::vector& tokens) { + if (tokens.size() == 2 && tokens[1] == "on") { + hlo_render_options.show_fusion_subcomputations = true; + } else if (tokens.size() == 2 && tokens[1] == "off") { + hlo_render_options.show_fusion_subcomputations = false; + } else if (tokens.size() != 1) { + std::cerr << "(Illegal show_fusion_subcomputations value. Use either " + "'on' or 'off'.)" + << std::endl; + } + std::cout << "Fusion subcomputations display " + << (hlo_render_options.show_fusion_subcomputations ? "ON" : "OFF") + << std::endl; } // List all computations in the module. @@ -373,7 +391,7 @@ void DoExtractCommand(const HloModule& module, auto extracted_module = ExtractModule(instr, height); std::cout << extracted_module->ToString( HloPrintOptions::ShortParsable().set_print_backend_config( - show_backend_config)) + hlo_render_options.show_backend_config)) << std::endl; } @@ -517,7 +535,7 @@ void DoAllPathsCommand(const Options& opts, const HloModule& module, } RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) { return RenderAllPathsFromTo(*from, *to, max_nodes, format, - /*show_backend_config=*/show_backend_config); + hlo_render_options); }); } @@ -582,15 +600,13 @@ void DoPlotCommand(const Options& opts, const HloModule& module, RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) { return RenderGraph(*comp, /*label=*/"", comp->parent()->config().debug_options(), format, - /*hlo_execution_profile=*/nullptr, - /*show_backend_config=*/show_backend_config); + /*hlo_execution_profile=*/nullptr, hlo_render_options); }); } else { RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) { - return RenderNeighborhoodAround( - *instr, graph_width, format, - /*show_backend_config=*/show_backend_config, - /*boundary=*/boundary); + return RenderNeighborhoodAround(*instr, graph_width, format, + hlo_render_options, + /*boundary=*/boundary); }); } } @@ -617,6 +633,8 @@ void InteractiveDumpGraphs(const Options& opts, const HloModule& module) { DoHelpCommand(); } else if (tokens[0] == "backend_config") { DoBackendConfigCommand(tokens); + } else if (tokens[0] == "show_fusion_subcomputations") { + DoShowFusionSubcomputationsCommand(tokens); } else if (tokens[0] == "list") { if (tokens.size() > 1 && tokens[1] == "computations") { DoListComputationsCommand(module, tokens); diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index a015af674af..6595bcbe292 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -148,9 +148,20 @@ message DebugOptions { // xla_cpu_enable_fast_math is false. bool xla_cpu_fast_math_honor_functions = 129; + // When false we lower the Minimum and Maximum hlos in the CPU backend such + // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NaN. 
In other words, if flag + // this is false we always propagate NaNs through Min and Max. + // + // Note, this does not correspond to the exact same behavior as the gpu flag + // below! + bool xla_cpu_enable_fast_min_max = 140; + // When true we lower the Minimum and Maximum hlos in the GPU backend such // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NotNaN. In other words, if flag // this is true we don't propagate NaNs through Min and Max. + // + // Note, this does not correspond to the exact same behavior as the cpu flag + // above! bool xla_gpu_enable_fast_min_max = 100; // Allows xla to increase the output precision of floating point operations. @@ -276,18 +287,16 @@ message DebugOptions { // memory, or have bugs. bool xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found = 138; - // It is usually preferable to not fallback to the driver; it can consume more - // memory, or have bugs. - bool xla_gpu_unsafe_fallback_to_driver_on_ptxas_error = 139; - - // Next id: 140 + // Next id: 141 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. map xla_backend_extra_options = 500; - reserved 5, 117, 133; // were xla_hlo_dump_as_graphdef, xla_dump_to, and - // xla_gpu_use_horizontal_fusion + reserved 5, 117, 133, + 139; // were xla_hlo_dump_as_graphdef, xla_dump_to, + // xla_gpu_use_horizontal_fusion, and + // xla_gpu_unsafe_fallback_to_driver_on_ptxas_error } // These settings control how XLA compiles and/or runs code. Not all settings @@ -333,6 +342,10 @@ message ExecutionOptions { // Used to identify a set of programs that should be launch together. int32 launch_id = 10; + + // Indicates whether to use SPMD (true) or MPMD (false) partitioning when + // num_partitions > 1 and XLA is requested to partition the input program. + bool use_spmd_partitioning = 11; } message GetDeviceHandlesRequest { diff --git a/tensorflow/compiler/xrt/kernels/BUILD b/tensorflow/compiler/xrt/kernels/BUILD index d71e6e2cc73..494ba29e981 100644 --- a/tensorflow/compiler/xrt/kernels/BUILD +++ b/tensorflow/compiler/xrt/kernels/BUILD @@ -49,6 +49,7 @@ cc_library( deps = [ ":xrt_state_ops", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc index 1bcd8561e61..ba6e6a093d6 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc @@ -158,7 +158,7 @@ Status XRTCompileOp::Compile(OpKernelContext* ctx, argument_layout_ptrs[i] = &argument_layouts[i]; } xla::ExecutableBuildOptions build_options; - build_options.set_device_ordinal(client->default_device_ordinal()); + build_options.set_device_ordinal(device_ref.device_ordinal()); build_options.set_num_replicas(num_replicas); build_options.set_result_layout(xla::Shape(config.program_shape().result())); build_options.set_device_allocator(device_ref.backend()->memory_allocator()); @@ -206,7 +206,8 @@ void XRTCompileOp::Compute(OpKernelContext* ctx) { OP_REQUIRES_OK(ctx, CompilationCacheKey(computation_proto, &key)); // Process-wide cache of XLA executables. 
- auto cache_or = GetOrCreateCompilationCache(rm, /*max_number_of_entries=*/0); + auto cache_or = XRTGenericDeviceAccessor::GetOrCreateCompilationCache( + ctx, /*max_number_of_entries=*/0); OP_REQUIRES_OK(ctx, cache_or.status()); auto cache = cache_or.ConsumeValueOrDie(); @@ -259,15 +260,11 @@ void XRTReleaseCompilationRefOp::Compute(OpKernelContext* ctx) { VLOG(1) << "XRTReleaseCompilationRefOp::Compute"; auto timed = monitoring::MakeTimed(xrt_metrics::GetReleaseCompilationCell()); - ResourceMgr* rm; - OP_REQUIRES_OK(ctx, XRTGenericDeviceAccessor::GetResourceManager(ctx, &rm)); - // Process-wide cache of XLA executables. - XRTCompilationCache* cache; - OP_REQUIRES_OK(ctx, rm->Lookup( - rm->default_container(), - kXRTCompilationCacheResourceName, &cache)); - core::ScopedUnref cache_unref(cache); + auto cache_or = XRTGenericDeviceAccessor::GetOrCreateCompilationCache( + ctx, /*max_number_of_entries=*/0); + OP_REQUIRES_OK(ctx, cache_or.status()); + auto cache = cache_or.ConsumeValueOrDie(); const Tensor& keys_tensor = ctx->input(0); auto flat_keys = keys_tensor.flat(); diff --git a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc index b641f333e8b..2fc599e42df 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h" #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" @@ -38,7 +39,11 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/monitoring/timed.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/device_memory.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" +#include "tensorflow/stream_executor/platform.h" #include "tensorflow/stream_executor/stream_executor.h" #include "tensorflow/stream_executor/stream_executor_internal.h" @@ -146,16 +151,245 @@ xla::StatusOr GetChainedOpInputs( return std::move(input_buffers); } +// Given a shape, returns a byte array representing the shape metadata of the +// shape. The shape metadata contains dimensions sizes stored as contiguous S32. +std::vector PrepareMetadata(const xla::Shape& shape) { + DCHECK(shape.is_static()); + DCHECK(shape.IsArray()); + // Each dimension size is stored as a S32. + std::vector result(shape.dimensions_size()); + for (int64 i = 0; i < shape.dimensions_size(); ++i) { + result[i] = shape.dimensions(i); + } + return result; +} + +// Given a buffer with dynamic shape, update buffer metadata at the correct +// offset starting from that buffer. +// +// +-----------+ +// |Payload | +// +-----------+ +// | Padding | +// +-----------+ +// |dim_size_0 | (each dim_size is a S32): +// +-----------+ +// |dim_size_1 | +// +-----------+ +// .......... 
+// +-----------+ +// +// Size of payload = ByteSizeOf(runtime_shape) +// Size of payload + padding = ByteSizeOf(compile_time_shape_static) +// Size of payload + padding + metadata = ByteSizeOf(compile_time_shape) +Status UpdateMetadata(se::Stream* stream, se::DeviceMemory* buffer, + const xla::Shape& compile_time_shape, + const xla::Shape& runtime_shape) { + TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform( + stream->parent()->platform())); + TF_ASSIGN_OR_RETURN( + auto transfer_manager, + xla::TransferManager::GetForPlatform(stream->parent()->platform())); + auto shape_size_fn = compiler->ShapeSizeBytesFunction(); + xla::Shape compile_time_shape_static = + xla::ShapeUtil::MakeStaticShape(compile_time_shape); + uint64 offset = shape_size_fn(compile_time_shape_static); + uint64 metadata_size = shape_size_fn(compile_time_shape) - offset; + auto metadata_buffer = + stream->parent()->GetSubBuffer(buffer, offset, metadata_size); + + auto metadata_literal = std::make_shared( + xla::LiteralUtil::CreateR1(PrepareMetadata(runtime_shape))); + TF_RETURN_IF_ERROR(transfer_manager->TransferArrayToDeviceAsync( + stream, *metadata_literal, metadata_buffer)); + // Retain the literal until the end of the transfer. + stream->ThenDoHostCallback([metadata_literal]() { return Status::OK(); }); + return Status::OK(); +} + +// Given a static input buffer, convert it to dynamic form by expanding it to +// the bounded size and attaching a metadata filled with dimension sizes. +// +// From: +// +--------+ +// |Payload | +// +--------+ +// +// To: +// +// +--------+ +// |Payload | +// +--------+ +// | Padding| +// +--------+ +// |Metadata| +// +--------+ +// +// As we can't expand the size of an existing memory allocation, a reallocation +// is required. A list of new allocations are returned after this function. The +// caller is reponsible for maintaining those allocations. +xla::StatusOr> UpdateDynamicInputs( + se::Stream* stream, se::DeviceMemoryAllocator* allocator, + std::vector runtime_inputs, + const std::vector& compile_time_shapes) { + std::vector new_allocations; + TF_RET_CHECK(runtime_inputs.size() == compile_time_shapes.size()); + TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform( + stream->parent()->platform())); + auto shape_size_fn = compiler->ShapeSizeBytesFunction(); + for (int64 i = 0; i < compile_time_shapes.size(); i++) { + const xla::Shape& compile_time_shape = compile_time_shapes[i].shape(); + if (compile_time_shape.is_static()) { + continue; + } + auto* runtime_input = runtime_inputs[i]; + + bool element_modified = false; + TF_RETURN_IF_ERROR(xla::ShapeUtil::ForEachSubshapeWithStatus( + compile_time_shape, + [&](const xla::Shape& compile_time_shape, + const xla::ShapeIndex& index) -> Status { + if (compile_time_shape.IsTuple() || compile_time_shape.is_static()) { + return Status::OK(); + } + const xla::Shape& runtime_shape = xla::ShapeUtil::GetSubshape( + runtime_input->on_device_shape(), index); + TF_RET_CHECK(!runtime_shape.IsTuple()); + TF_RET_CHECK(xla::ShapeUtil::DynamicShapeIsCompatible( + runtime_shape, compile_time_shape)); + se::DeviceMemoryBase* static_input = + runtime_input->buffers().mutable_element(index); + TF_ASSIGN_OR_RETURN( + auto dynamic_input, + allocator->Allocate(stream->parent()->device_ordinal(), + shape_size_fn(compile_time_shape))); + new_allocations.emplace_back(std::move(dynamic_input)); + se::DeviceMemory* dynamic_input_base = + new_allocations.back().ptr(); + // Send the original data to the new location. 
+ stream->ThenMemcpyD2D(dynamic_input_base, *static_input, + static_input->size()); + TF_RETURN_IF_ERROR(UpdateMetadata(stream, dynamic_input_base, + compile_time_shape, runtime_shape)); + // Modify the memory location in the input shape tree to point to the + // new input. + runtime_input->set_buffer(*dynamic_input_base, index); + element_modified = true; + return Status::OK(); + })); + if (element_modified) { + runtime_input->set_shapes(compile_time_shape, compile_time_shape); + // The input location has been modified, need to fix tuple table to + // point to the correct address. + TF_ASSIGN_OR_RETURN( + auto transfer_manager, + xla::TransferManager::GetForPlatform(stream->parent()->platform())); + TF_RETURN_IF_ERROR( + transfer_manager->WriteTupleIndexTablesAsync(stream, *runtime_input)); + } + } + return std::move(new_allocations); +} + +xla::StatusOr ReadMetadataLiteral( + se::Stream* stream, se::DeviceMemoryBase* buffer, + const xla::Shape& buffer_shape, xla::TransferManager* transfer_manager) { + TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform( + stream->parent()->platform())); + auto shape_size_fn = compiler->ShapeSizeBytesFunction(); + xla::Shape buffer_shape_static = + xla::ShapeUtil::MakeStaticShape(buffer_shape); + const int64 offset = shape_size_fn(buffer_shape_static); + int64 metadata_size = shape_size_fn(buffer_shape) - offset; + TF_RET_CHECK(metadata_size != 0); + auto buffer_8 = se::DeviceMemory(*buffer); + auto metadata_buffer = + stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size); + return transfer_manager->TransferArrayFromDevice( + stream, + xla::ShapeUtil::MakeShape(xla::S32, {buffer_shape.dimensions_size()}), + metadata_buffer); +} + +// For each subshape in the result buffer that's dynamic, read the dynamic +// dimension sizes from the metadata, and update output shapes. The result shape +// is a static and concrete shape. +xla::Status UpdateDynamicOutputs(se::Stream* stream, + xla::ShapedBuffer* shaped_buffer, + xla::Shape* output_host_shape, + xla::Shape* output_device_shape) { + DCHECK(output_device_shape->is_dynamic()); + TF_ASSIGN_OR_RETURN( + auto transfer_manager, + xla::TransferManager::GetForPlatform(stream->parent()->platform())); + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + TF_RETURN_IF_ERROR(shaped_buffer->buffers().ForEachMutableElementWithStatus( + [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) { + const xla::Shape& buffer_shape = + xla::ShapeUtil::GetSubshape(*output_device_shape, index); + if (buffer_shape.IsTuple()) { + return Status::OK(); + } + xla::Shape& host_shape = + *xla::ShapeUtil::GetMutableSubshape(output_host_shape, index); + xla::Shape& device_shape = + *xla::ShapeUtil::GetMutableSubshape(output_device_shape, index); + if (device_shape.is_static()) { + return Status::OK(); + } + TF_ASSIGN_OR_RETURN(auto metadata, + ReadMetadataLiteral(stream, buffer, buffer_shape, + transfer_manager)); + // Update shape size from metadata. + for (int64 i = 0; i < metadata.element_count(); ++i) { + host_shape.mutable_dimensions()[i] = metadata.Get({i}); + device_shape.mutable_dimensions()[i] = metadata.Get({i}); + } + return Status::OK(); + })); + output_host_shape->clear_dynamic_dimensions(); + output_device_shape->clear_dynamic_dimensions(); + return Status::OK(); +} + +// Create output tuple from run_result. 
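
To make the buffer layout used by UpdateMetadata and ReadMetadataLiteral above concrete, the helper below is an illustrative sketch (not part of this change) that computes where the appended dimension-size metadata lives for a bounded dynamic shape; shape_size_fn stands in for the compiler's ShapeSizeBytesFunction(). The function that follows then assembles the output tuple from the executed result.

// Illustrative only: for e.g. f32[<=4], the payload occupies the bytes of the
// static shape f32[4]; the metadata (one S32 per dimension) is appended after
// any padding, so offset + metadata_size == shape_size_fn(compile_time_shape).
std::pair<uint64, uint64> MetadataOffsetAndSize(
    const xla::Shape& compile_time_shape,
    const std::function<int64(const xla::Shape&)>& shape_size_fn) {
  xla::Shape static_shape = xla::ShapeUtil::MakeStaticShape(compile_time_shape);
  uint64 offset = shape_size_fn(static_shape);                        // payload + padding
  uint64 metadata_size = shape_size_fn(compile_time_shape) - offset;  // dim sizes as S32
  return {offset, metadata_size};
}
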
+xla::StatusOr> CreateOutputTuple( + se::Stream* stream, xla::ScopedShapedBuffer run_result, + xla::Backend* backend, int device_ordinal) { + XRTTupleAllocation* output_tuple; + xla::ShapedBuffer shaped_buffer = run_result.release(); + if (shaped_buffer.on_device_shape().is_dynamic()) { + // Update dynamic shapes from output buffer, and create a XRT tensor with + // dimension sizes read from metadata. + xla::Shape output_host_shape = shaped_buffer.on_host_shape(); + xla::Shape output_device_shape = shaped_buffer.on_device_shape(); + TF_RETURN_IF_ERROR(UpdateDynamicOutputs( + stream, &shaped_buffer, &output_host_shape, &output_device_shape)); + TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer( + shaped_buffer, output_host_shape, output_device_shape, backend, + device_ordinal, &output_tuple)); + } else { + // Fast-path: Don't copy shapes of output buffer. + TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer( + shaped_buffer, backend, device_ordinal, &output_tuple)); + } + return RefPtr(output_tuple); +} + xla::StatusOr> RunExecutable( OpKernelContext* context, XRTGenericDeviceAccessor::ScopedRef* device_ref, xla::LocalExecutable* executable, const InputBuffers& input_buffers, - se::Stream* stream, int rng_seed, int replica_id) { + se::Stream* stream, int rng_seed, + const xrt::CommonExecutionConfig& config) { VLOG(2) << "Executing computation."; xla::ExecutableRunOptions run_options; run_options.set_stream(stream); run_options.set_allocator(device_ref->backend()->memory_allocator()); run_options.set_intra_op_thread_pool(&context->eigen_cpu_device()); run_options.set_rng_seed(rng_seed); + if (config.run_id() != 0) { + run_options.set_run_id(xla::RunId(config.run_id())); + } if (executable->executable() ->module_config() .has_static_device_assignment()) { @@ -164,8 +398,11 @@ xla::StatusOr> RunExecutable( } xla::GpuExecutableRunOptions gpu_options; std::vector gpu_global_ids; - if (replica_id >= 0) { - gpu_global_ids.emplace_back(replica_id); + if (config.local_replica_mapping_size() > 0) { + gpu_global_ids.reserve(config.local_replica_mapping_size()); + for (auto& gid : config.local_replica_mapping()) { + gpu_global_ids.emplace_back(xla::GlobalDeviceId(gid)); + } gpu_options.set_gpu_global_device_ids(gpu_global_ids); } std::shared_ptr nccl_factory = GetNcclUniqueIdFactory(); @@ -184,18 +421,31 @@ xla::StatusOr> RunExecutable( Env* env = Env::Default(); auto start_time = env->NowMicros(); + const std::vector& shape_layouts = + executable->executable() + ->module_config() + .entry_computation_layout() + .parameter_layouts(); + TF_ASSIGN_OR_RETURN(auto new_allocations, + UpdateDynamicInputs(stream, run_options.allocator(), + input_buffers.input_pointers, + shape_layouts)); + auto new_allocations_ptr = + std::make_shared>( + std::move(new_allocations)); TF_ASSIGN_OR_RETURN( xla::ScopedShapedBuffer run_result, executable->Run(input_buffers.input_pointers, run_options)); + // Retain the new allocation for input memory until the end of execution. 
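
The host callback on the next line is the same keep-alive idiom already used in UpdateMetadata above: a no-op callback that captures a shared_ptr so the captured allocations outlive every operation previously enqueued on the stream. In generic form (names are placeholders):

auto data = std::make_shared<std::vector<uint8>>(128);
// ... enqueue asynchronous work on `stream` that reads from *data ...
stream->ThenDoHostCallback([data]() { return Status::OK(); });  // `data` released no earlier than here
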
+ stream->ThenDoHostCallback([new_allocations_ptr]() { return Status::OK(); }); + auto elapsed = env->NowMicros() - start_time; VLOG(2) << "Elapsed time: " << elapsed << "us"; - auto shaped_buffer = run_result.release(); - XRTTupleAllocation* output_tuple; - TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer( - shaped_buffer, device_ref->backend(), device_ref->device_ordinal(), - &output_tuple)); - RefPtr output_tuple_ptr(output_tuple); + TF_ASSIGN_OR_RETURN( + RefPtr output_tuple_ptr, + CreateOutputTuple(stream, std::move(run_result), device_ref->backend(), + device_ref->device_ordinal())); // The ScopedShapedBuffer returned by the executable Run() API, in case of // input/output buffer aliasing, might have holes in it, which need to be @@ -208,7 +458,7 @@ xla::StatusOr> RunExecutable( const xla::HloInputOutputAliasConfig::Alias& alias) -> Status { TF_RET_CHECK(alias.parameter_number < input_buffers.input_tuples.size()); return alias.kind == xla::HloInputOutputAliasConfig::AliasKind::kUserAlias - ? output_tuple->AliasBufferFrom( + ? output_tuple_ptr->AliasBufferFrom( *input_buffers.input_tuples[alias.parameter_number], alias.parameter_index, output_index) : Status::OK(); @@ -222,10 +472,11 @@ xla::StatusOr> ExecuteComputation( OpKernelContext* context, XRTMemoryManager* memory_manager, XRTGenericDeviceAccessor::ScopedRef* device_ref, xla::LocalExecutable* executable, const InputBuffers& input_buffers, - se::Stream* stream, int rng_seed, int replica_id) { + se::Stream* stream, int rng_seed, + const xrt::CommonExecutionConfig& config) { auto runfn = [&]() { return RunExecutable(context, device_ref, executable, input_buffers, stream, - rng_seed, replica_id); + rng_seed, config); }; // We pass zero as requested_free_size as there is no simple way to get the @@ -241,14 +492,15 @@ xla::StatusOr> ExecuteComputation( XRTGenericDeviceAccessor::ScopedRef* device_ref, xla::LocalExecutable* executable, const std::vector& input_coords, bool release_inputs, - se::Stream* stream, int rng_seed, int replica_id) { + se::Stream* stream, int rng_seed, + const xrt::CommonExecutionConfig& config) { XRTMemoryManager::WorkingSet working_set(memory_manager); TF_ASSIGN_OR_RETURN(InputBuffers input_buffers, GetInputBuffers(&working_set, device_ref->backend(), input_coords, release_inputs)); return ExecuteComputation(context, memory_manager.get(), device_ref, executable, input_buffers, stream, rng_seed, - replica_id); + config); } // XRTExecuteOp @@ -297,8 +549,9 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) { bool release_inputs = config_proto.release_input_handles(); bool release_compilation = config_proto.release_compilation_handle(); - TF_ASSIGN_OR_RETURN( - auto cache, GetOrCreateCompilationCache(rm, /*max_number_of_entries=*/0)); + TF_ASSIGN_OR_RETURN(auto cache, + XRTGenericDeviceAccessor::GetOrCreateCompilationCache( + context, /*max_number_of_entries=*/0)); // We are guaranteed that the underlying device object won't be deleted out // from under us, while the ScopedRef is live. 
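
Since the execution path above now takes a xrt::CommonExecutionConfig instead of a bare replica_id, a client building the XRTExecutionConfig proto would populate the shared fields roughly as sketched below; the field names come from the message added to xrt.proto later in this change, and the values are placeholders. The ScopedRef declared next pins the underlying device for the duration of the run, as the comment above notes.

xrt::XRTExecutionConfig exec_config;
exec_config.set_release_input_handles(true);
auto* common = exec_config.mutable_common_config();
common->set_run_id(42);                // correlates executes launched in parallel
common->add_local_replica_mapping(0);  // local device ordinal 0 -> global replica 0
common->add_local_replica_mapping(1);  // local device ordinal 1 -> global replica 1
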
class XRTGenericDeviceAccessor::ScopedRef device_ref; @@ -330,7 +583,7 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) { RefPtr output_tuple, ExecuteComputation(context, memory_manager, &device_ref, executable, input_coords, release_inputs, stream, rng_seed, - config_proto.replica_id())); + config_proto.common_config())); return CreateExecuteOutput(context, memory_manager.get(), std::move(output_tuple), @@ -379,8 +632,9 @@ Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) { xrt::XRTChainedExecuteConfig config; TF_RET_CHECK(ParseFromTString(execution_config.scalar()(), &config)); - TF_ASSIGN_OR_RETURN( - auto cache, GetOrCreateCompilationCache(rm, /*max_number_of_entries=*/0)); + TF_ASSIGN_OR_RETURN(auto cache, + XRTGenericDeviceAccessor::GetOrCreateCompilationCache( + context, /*max_number_of_entries=*/0)); // We are guaranteed that the underlying device object won't be deleted out // from under us, while the ScopedRef is live. class XRTGenericDeviceAccessor::ScopedRef device_ref; @@ -408,7 +662,7 @@ Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) { return ExecuteComputation(context, memory_manager.get(), &device_ref, executable, input_buffers, stream, rng_seed, - config.replica_id()); + config.common_config()); }; return ExecuteChained(context, memory_manager, device_ref.backend(), diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc index 243289c8821..67647cc4285 100644 --- a/tensorflow/compiler/xrt/tests/raw_api_test.cc +++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc @@ -49,6 +49,91 @@ limitations under the License. namespace tensorflow { namespace { +xla::XlaComputation ReturnDynamicR1() { + xla::XlaBuilder builder("ReturnDynamicR1"); + auto p0 = xla::Parameter(&builder, 0, + xla::ShapeUtil::MakeShape(xla::F32, {4}), "P0"); + auto p1 = xla::Parameter(&builder, 1, + xla::ShapeUtil::MakeShape(xla::F32, {4}), "P1"); + auto p2 = xla::Parameter(&builder, 2, xla::ShapeUtil::MakeShape(xla::S32, {}), + "P2"); + auto sum = xla::Add(p0, p1); + auto pad_sum = xla::SetDimensionSize(sum, p2, 0); + return builder.Build(pad_sum).ValueOrDie(); +} + +xla::XlaComputation ReturnDynamicR2() { + xla::XlaBuilder builder("ReturnDynamicR2"); + auto p0 = xla::Parameter(&builder, 0, + xla::ShapeUtil::MakeShape(xla::F32, {2, 4}), "P0"); + auto p1 = xla::Parameter(&builder, 1, + xla::ShapeUtil::MakeShape(xla::F32, {2, 4}), "P1"); + auto p2 = xla::Parameter(&builder, 2, xla::ShapeUtil::MakeShape(xla::S32, {}), + "P2"); + auto sum = xla::Add(p0, p1); + auto pad_sum_dim0 = xla::SetDimensionSize(sum, p2, 0); + auto pad_sum_dim1 = xla::SetDimensionSize(pad_sum_dim0, p2, 1); + return builder.Build(pad_sum_dim1).ValueOrDie(); +} + +xla::XlaComputation AcceptDynamicR1() { + xla::XlaBuilder builder("AcceptDynamicR1"); + xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_shape.set_dynamic_dimension(0, true); + auto p0 = xla::Parameter(&builder, 0, dyn_shape, "P0"); + auto p1 = xla::Parameter(&builder, 1, dyn_shape, "P1"); + auto sum = xla::Add(p0, p1); + return builder.Build(sum).ValueOrDie(); +} + +xla::XlaComputation AcceptDynamicR2() { + xla::XlaBuilder builder("AcceptDynamicR2"); + xla::Shape dyn_shape; + dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 4}); + dyn_shape.set_dynamic_dimension(1, true); + auto p0 = xla::Parameter(&builder, 0, dyn_shape, "P0"); + auto negate = xla::Neg(p0); + return builder.Build(negate).ValueOrDie(); +} + +xla::XlaComputation ReturnDynamicR1Tuple() { + xla::XlaBuilder 
builder("ReturnDynamicR1Tuple"); + auto p0 = xla::Parameter(&builder, 0, + xla::ShapeUtil::MakeShape(xla::F32, {4}), "P0"); + auto p1 = xla::Parameter(&builder, 1, + xla::ShapeUtil::MakeShape(xla::F32, {4}), "P1"); + auto p2 = xla::Parameter(&builder, 2, xla::ShapeUtil::MakeShape(xla::S32, {}), + "P2"); + auto sum = xla::Add(p0, p1); + auto sub = xla::Sub(p0, p1); + auto one = xla::One(&builder, xla::S32); + auto pad_sum = xla::SetDimensionSize(sum, p2, 0); + auto pad_sub = xla::SetDimensionSize(sub, p2 + one, 0); + auto tuple = xla::Tuple(&builder, {pad_sum, sum, pad_sub}); + return builder.Build(tuple, /*remove_dynamic_dimensions=*/true).ValueOrDie(); +} + +xla::XlaComputation AcceptDynamicR1Tuple() { + xla::XlaBuilder builder("AcceptDynamicR1"); + xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_shape.set_dynamic_dimension(0, true); + xla::Shape tuple_shape = + xla::ShapeUtil::MakeTupleShape({dyn_shape, dyn_shape}); + xla::Shape nest_tuple_shape = + xla::ShapeUtil::MakeTupleShape({dyn_shape, dyn_shape}); + auto p = xla::Parameter(&builder, 0, tuple_shape, "P0"); + auto p0 = xla::GetTupleElement(p, 0); + auto p1 = xla::GetTupleElement(p, 1); + auto sum = xla::Add(p0, p1); + return builder.Build(sum).ValueOrDie(); +} + +template +xla::LiteralProto CreateR0(T v) { + auto array = xla::LiteralUtil::CreateR0(v); + return array.ToProto(); +} + class XrtClientSession : public ClientSession { public: explicit XrtClientSession(const Scope& scope) : ClientSession(scope) { @@ -61,6 +146,11 @@ class XrtClientSession : public ClientSession { string* xla_test_device_ptr; // initial value set in main() string* xla_platform_ptr; // initial value set in main() +bool SupportDynamicShapes() { + // TODO(jackcao): Support dynamic shapes on XLA GPU. 
+ return *xla_test_device_ptr != "XLA_GPU"; +} + string DeviceFromFlag() { string xla_test_device = *xla_test_device_ptr; return absl::StrCat("/device:", xla_test_device, ":0"); @@ -1035,6 +1125,353 @@ TEST(RawApiTest, CompileAndExecute) { EXPECT_EQ(program_shape.parameters_size(), 2); } +TEST(RawApiTest, DynamicR1Test) { + if (!SupportDynamicShapes()) { + GTEST_SKIP() + << "Skipping the test if backend doesn't support dynamic shapes"; + } + xrt::XLAAllocation p0; + *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f}); + xrt::XLAAllocation p1; + *p1.mutable_value() = FloatVector({1.0f, -1.0f, 2.5f, 1.17f}); + xrt::XLAAllocation p2; + *p2.mutable_value() = CreateR0(2); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto(); + *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto(); + xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_shape.set_dynamic_dimension(0, true); + *shapes->mutable_result() = dyn_shape.ToProto(); + StoreComputationSnapshot(ReturnDynamicR1(), c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + Scope cpu_root = root.WithDevice("/device:CPU:0"); + auto e_config = ops::Const(cpu_root, e.SerializeAsString()); + auto computation = ops::Const(cpu_root, c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto p0_value = ops::Const(cpu_root, p0.SerializeAsString()); + auto p0_handle = ops::XRTAllocate(root, p0_value); + auto p1_value = ops::Const(cpu_root, p1.SerializeAsString()); + auto p1_handle = ops::XRTAllocate(root, p1_value); + auto p2_value = ops::Const(cpu_root, p2.SerializeAsString()); + auto p2_handle = ops::XRTAllocate(root, p2_value); + auto result = ops::XRTExecute( + root, c_handle.handle, e_config, + {Output(p0_handle), Output(p1_handle), Output(p2_handle)}); + auto read_back = ops::XRTReadLiteralAndRelease(root, result); + TF_ASSERT_OK(root.status()); + + XrtClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + auto expected = xla::LiteralUtil::CreateR1({2.0f, 1.0f}); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); +} + +TEST(RawApiTest, DynamicR2Test) { + if (!SupportDynamicShapes()) { + GTEST_SKIP() + << "Skipping the test if backend doesn't support dynamic shapes"; + } + xrt::XLAAllocation p0; + *p0.mutable_value() = xla::LiteralUtil::CreateR2({{1.0f, 2.0f, 0.5f, -1.0f}, + {1.5f, 2.5f, 3.0f, -2.0f}}) + .ToProto(); + xrt::XLAAllocation p1; + *p1.mutable_value() = xla::LiteralUtil::CreateR2({{1.0f, -1.0f, 2.5f, 1.17f}, + {1.2f, -1.6f, 2.8f, 1.24f}}) + .ToProto(); + xrt::XLAAllocation p2; + *p2.mutable_value() = CreateR0(2); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2, 4}).ToProto(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2, 4}).ToProto(); + *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto(); + xla::Shape dyn_shape = 
xla::ShapeUtil::MakeShape(xla::F32, {2, 4}); + dyn_shape.set_dynamic_dimension(0, true); + dyn_shape.set_dynamic_dimension(1, true); + *shapes->mutable_result() = dyn_shape.ToProto(); + StoreComputationSnapshot(ReturnDynamicR2(), c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + Scope cpu_root = root.WithDevice("/device:CPU:0"); + auto e_config = ops::Const(cpu_root, e.SerializeAsString()); + auto computation = ops::Const(cpu_root, c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto p0_value = ops::Const(cpu_root, p0.SerializeAsString()); + auto p0_handle = ops::XRTAllocate(root, p0_value); + auto p1_value = ops::Const(cpu_root, p1.SerializeAsString()); + auto p1_handle = ops::XRTAllocate(root, p1_value); + auto p2_value = ops::Const(cpu_root, p2.SerializeAsString()); + auto p2_handle = ops::XRTAllocate(root, p2_value); + auto result = ops::XRTExecute( + root, c_handle.handle, e_config, + {Output(p0_handle), Output(p1_handle), Output(p2_handle)}); + auto read_back = ops::XRTReadLiteralAndRelease(root, result); + TF_ASSERT_OK(root.status()); + + XrtClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + auto expected = xla::LiteralUtil::CreateR2({{2.0f, 1.0f}, {2.7, 0.9}}); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); +} + +TEST(RawApiTest, DynamicR1TupleTest) { + if (!SupportDynamicShapes()) { + GTEST_SKIP() + << "Skipping the test if backend doesn't support dynamic shapes"; + } + xrt::XLAAllocation p0; + *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f}); + xrt::XLAAllocation p1; + *p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f, 1.0f}); + xrt::XLAAllocation p2; + *p2.mutable_value() = CreateR0(2); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto(); + *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto(); + xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_shape.set_dynamic_dimension(0, true); + *shapes->mutable_result() = + xla::ShapeUtil::MakeTupleShape( + {dyn_shape, xla::ShapeUtil::MakeShape(xla::F32, {4}), dyn_shape}) + .ToProto(); + StoreComputationSnapshot(ReturnDynamicR1Tuple(), c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + Scope cpu_root = root.WithDevice("/device:CPU:0"); + auto e_config = ops::Const(cpu_root, e.SerializeAsString()); + auto computation = ops::Const(cpu_root, c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto p0_value = ops::Const(cpu_root, p0.SerializeAsString()); + auto p0_handle = ops::XRTAllocate(root, p0_value); + auto p1_value = ops::Const(cpu_root, p1.SerializeAsString()); + auto p1_handle = ops::XRTAllocate(root, p1_value); + auto p2_value = ops::Const(cpu_root, p2.SerializeAsString()); + auto p2_handle = ops::XRTAllocate(root, p2_value); + auto result = ops::XRTExecute( + root, c_handle.handle, e_config, 
+ {Output(p0_handle), Output(p1_handle), Output(p2_handle)}); + auto read_back = ops::XRTReadLiteralAndRelease(root, result); + TF_ASSERT_OK(root.status()); + + XrtClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + + auto expected0 = xla::LiteralUtil::CreateR1({2.0f, 1.0f}); + auto expected1 = xla::LiteralUtil::CreateR1({2.0f, 1.0f, 0.0f, 0.0f}); + auto expected2 = xla::LiteralUtil::CreateR1({0.0f, 3.0f, 1.0f}); + auto expected = + xla::LiteralUtil::MakeTuple({&expected0, &expected1, &expected2}); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); +} + +TEST(RawApiTest, AcceptDynamicR1TupleTest) { + if (!SupportDynamicShapes()) { + GTEST_SKIP() + << "Skipping the test if backend doesn't support dynamic shapes"; + } + xrt::XLAAllocation p0; + *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f}); + xrt::XLAAllocation p1; + *p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f}); + + xrt::XLATupleNode tuple_desc; + auto subdesc_10 = tuple_desc.add_tuples(); + auto subdesc_11 = tuple_desc.add_tuples(); + subdesc_10->set_input_index(0); + subdesc_10->set_release_input_handle(true); + subdesc_11->set_input_index(1); + subdesc_11->set_release_input_handle(true); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + xla::Shape dyn_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_input_shape.set_dynamic_dimension(0, true); + xla::Shape dyn_tuple_shape = + xla::ShapeUtil::MakeTupleShape({dyn_input_shape, dyn_input_shape}); + *shapes->add_parameters() = dyn_tuple_shape.ToProto(); + xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_shape.set_dynamic_dimension(0, true); + *shapes->mutable_result() = dyn_shape.ToProto(); + StoreComputationSnapshot(AcceptDynamicR1Tuple(), c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + Scope cpu_root = root.WithDevice("/device:CPU:0"); + auto e_config = ops::Const(cpu_root, e.SerializeAsString()); + auto computation = ops::Const(cpu_root, c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto p0_value = ops::Const(cpu_root, p0.SerializeAsString()); + auto p0_handle = ops::XRTAllocate(root, p0_value); + auto p1_value = ops::Const(cpu_root, p1.SerializeAsString()); + auto p1_handle = ops::XRTAllocate(root, p1_value); + + auto tuple_0 = ops::Const(root.WithDevice("/device:CPU:0"), + tuple_desc.SerializeAsString()); + auto t0_handle = ops::XRTMakeTuple( + root, tuple_0, + {static_cast(p0_handle), static_cast(p1_handle)}); + auto result = ops::XRTExecute(root, c_handle.handle, e_config, + {static_cast(t0_handle)}); + auto read_back = ops::XRTReadLiteralAndRelease(root, result); + TF_ASSERT_OK(root.status()); + + XrtClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + + auto expected = xla::LiteralUtil::CreateR1({2.0f, 1.0f, 0.0f}); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); +} + +TEST(RawApiTest, AcceptDynamicR1Test) { + if (!SupportDynamicShapes()) { + GTEST_SKIP() + << "Skipping the test if backend doesn't support 
dynamic shapes"; + } + xrt::XLAAllocation p0; + *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f}); + xrt::XLAAllocation p1; + *p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f}); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + xla::Shape dyn_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_input_shape.set_dynamic_dimension(0, true); + *shapes->add_parameters() = dyn_input_shape.ToProto(); + *shapes->add_parameters() = dyn_input_shape.ToProto(); + xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_shape.set_dynamic_dimension(0, true); + *shapes->mutable_result() = dyn_shape.ToProto(); + StoreComputationSnapshot(AcceptDynamicR1(), c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + Scope cpu_root = root.WithDevice("/device:CPU:0"); + auto e_config = ops::Const(cpu_root, e.SerializeAsString()); + auto computation = ops::Const(cpu_root, c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto p0_value = ops::Const(cpu_root, p0.SerializeAsString()); + auto allocate_op_0 = ops::XRTAllocate(root, p0_value); + auto p1_value = ops::Const(cpu_root, p1.SerializeAsString()); + auto allocate_op_1 = ops::XRTAllocate(root, p1_value); + auto result = ops::XRTExecute(root, c_handle.handle, e_config, + {Output(allocate_op_0), Output(allocate_op_1)}); + auto read_back = ops::XRTReadLiteralAndRelease(root, result); + TF_ASSERT_OK(root.status()); + + XrtClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + + auto expected = xla::LiteralUtil::CreateR1({2.0f, 1.0f, 0.0f}); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); +} + +TEST(RawApiTest, AcceptDynamicR2Test) { + if (!SupportDynamicShapes()) { + GTEST_SKIP() + << "Skipping the test if backend doesn't support dynamic shapes"; + } + xrt::XLAAllocation p0; + *p0.mutable_value() = + xla::LiteralUtil::CreateR2({{-1.0f, 3.0f, 1.0f}, {-2.0f, -1.0f, 3.0f}}) + .ToProto(); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + // Compile time expects ascending layout. 
+ xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 4}); + dyn_shape.set_dynamic_dimension(1, true); + *shapes->add_parameters() = dyn_shape.ToProto(); + + *shapes->mutable_result() = dyn_shape.ToProto(); + StoreComputationSnapshot(AcceptDynamicR2(), c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + Scope cpu_root = root.WithDevice("/device:CPU:0"); + auto e_config = ops::Const(cpu_root, e.SerializeAsString()); + auto computation = ops::Const(cpu_root, c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto p0_value = ops::Const(cpu_root, p0.SerializeAsString()); + auto p0_handle = ops::XRTAllocate(root, p0_value); + auto result = + ops::XRTExecute(root, c_handle.handle, e_config, {Output(p0_handle)}); + auto read_back = ops::XRTReadLiteralAndRelease(root, result); + TF_ASSERT_OK(root.status()); + + XrtClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + + auto expected = xla::LiteralUtil::CreateR2( + {{1.0f, -3.0f, -1.0f}, {2.0f, 1.0f, -3.0f}}); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); +} + TEST(RawApiTest, CompileAndExecuteWithArgumentVector) { xrt::XLAAllocation p0; *p0.mutable_value() = FloatVector({1.0f, 2.0f}); diff --git a/tensorflow/compiler/xrt/xrt.proto b/tensorflow/compiler/xrt/xrt.proto index 1cbd851f7ef..9a351732c4b 100644 --- a/tensorflow/compiler/xrt/xrt.proto +++ b/tensorflow/compiler/xrt/xrt.proto @@ -111,6 +111,17 @@ message XLATupleNode { repeated XLATupleNode tuples = 3; } +message CommonExecutionConfig { + // The replica index this execute is driving. + int32 replica_id = 1; + // Mapping local device ordinals to global replica IDs. + // local_replica_mapping[LOCAL_DEVICE_ORDINAL] = GLOBAL_REPLICA_ID + repeated int32 local_replica_mapping = 2; + // The execution run ID used to correlate different XRT execute operations + // happeining in parallel from different threads. + int64 run_id = 3; +} + // Options for an XLA execution. message XRTExecutionConfig { // Local device to run on. This is present because the execute Op @@ -133,8 +144,9 @@ message XRTExecutionConfig { // a single tuple allocation the execution will return a vector of // allocations, one for each of the first-level elements of the result tuple. bool return_exploded_tuple = 7; - // The replica index this execute is driving. - int32 replica_id = 8; + reserved 8; + // The common configuration for XRT execute operations. + CommonExecutionConfig common_config = 9; } message XRTChainedExecuteConfig { @@ -145,8 +157,9 @@ message XRTChainedExecuteConfig { // Optional key to disambiguate between executions. This is only needed if // multiple host send/recvs may be outstanding concurrently with executions. string execution_instance_key = 3; - // The replica index this execute is driving. - int32 replica_id = 4; + reserved 4; + // The common configuration for XRT execute operations. + CommonExecutionConfig common_config = 5; } // A single chained execute operation. 
An operation can either be a device data diff --git a/tensorflow/compiler/xrt/xrt_device.cc b/tensorflow/compiler/xrt/xrt_device.cc index 1b5557d556d..46954572c5d 100644 --- a/tensorflow/compiler/xrt/xrt_device.cc +++ b/tensorflow/compiler/xrt/xrt_device.cc @@ -17,19 +17,56 @@ limitations under the License. #include "tensorflow/compiler/xrt/xrt_device.h" +#include + #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/mutex.h" namespace tensorflow { +namespace { + +class ResourceMgrArena { + public: + static ResourceMgrArena* Get() { + static ResourceMgrArena* arena = new ResourceMgrArena(); + return arena; + } + + ResourceMgr* GetResourceMgr(const std::string& platform_name) { + mutex_lock lock(mutex_); + auto it = resource_managers_.find(platform_name); + if (it == resource_managers_.end()) { + it = resource_managers_.emplace(platform_name, new ResourceMgr()).first; + } + return it->second; + } + + private: + mutex mutex_; + std::map resource_managers_; +}; + +} // namespace /*static*/ Status XRTGenericDeviceAccessor::GetResourceManager( OpKernelContext* ctx, ResourceMgr** rm) { - *rm = ctx->resource_manager(); + const XlaDevice::Metadata* metadata; + TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(ctx, &metadata)); + *rm = ResourceMgrArena::Get()->GetResourceMgr(metadata->platform()->Name()); return Status::OK(); } +/* static */ xla::StatusOr> +XRTGenericDeviceAccessor::GetOrCreateCompilationCache( + OpKernelContext* ctx, int64 max_number_of_entries) { + ResourceMgr* rm; + TF_RETURN_IF_ERROR(GetResourceManager(ctx, &rm)); + return tensorflow::GetOrCreateCompilationCache(rm, max_number_of_entries); +} + /*static*/ Status XRTGenericDeviceAccessor::InitScopedRef( OpKernelContext* ctx, int device_ordinal, ScopedRef* scoped_ref) { const XlaDevice::Metadata* metadata; diff --git a/tensorflow/compiler/xrt/xrt_device.h b/tensorflow/compiler/xrt/xrt_device.h index 5ebee7641f0..02fab315830 100644 --- a/tensorflow/compiler/xrt/xrt_device.h +++ b/tensorflow/compiler/xrt/xrt_device.h @@ -19,6 +19,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XRT_XRT_DEVICE_H_ #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xrt/xrt_compilation_cache.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" @@ -31,6 +32,9 @@ class XRTGenericDeviceAccessor { public: static Status GetResourceManager(OpKernelContext* ctx, ResourceMgr** rm); + static xla::StatusOr> GetOrCreateCompilationCache( + OpKernelContext* ctx, int64 max_number_of_entries); + // We use a ScopedRef pattern here even though it's not strictly necessary, // just so that templated uses of this and the TPU accessor class will be as // similar as possible. 
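
The practical effect of the ResourceMgrArena above is that XRT state such as the compilation cache is now keyed by platform name rather than tied to the per-device ResourceMgr of the incoming OpKernelContext. A small sketch of the resulting sharing; ResourceMgrArena is internal to xrt_device.cc and the platform name string is only an example.

// Two XRT kernels running on different devices of the same platform resolve
// to the same ResourceMgr, and therefore share one XRTCompilationCache.
tensorflow::ResourceMgr* rm_a = ResourceMgrArena::Get()->GetResourceMgr("CUDA");
tensorflow::ResourceMgr* rm_b = ResourceMgrArena::Get()->GetResourceMgr("CUDA");
DCHECK_EQ(rm_a, rm_b);
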
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index b4bec2a6907..2b16801f6ed 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -83,7 +83,6 @@ load( "tf_gen_op_libs", "tf_genrule_cmd_append_to_srcs", "tf_opts_nortti_if_lite_protos", - "tf_opts_nortti_if_mobile", "tf_portable_full_lite_protos", "transitive_hdrs", ) @@ -100,28 +99,23 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu") # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_cc_tests_gpu") -# buildifier: disable=same-origin-load -# Placeholder: load("//tensorflow:tensorflow.bzl", "tf_portable_proto_lib") - # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_monitoring_deps") # For platform specific build config load( "//tensorflow/core/platform:build_config.bzl", - "tf_additional_all_protos", "tf_additional_lib_deps", "tf_additional_test_deps", "tf_jspb_proto_library", "tf_kernel_tests_linkstatic", "tf_lib_proto_parsing_deps", "tf_portable_deps_no_runtime", + "tf_portable_proto_lib", "tf_proto_library", - "tf_proto_library_cc", "tf_protos_all_impl", "tf_protos_grappler_impl", "tf_protos_profiler_impl", - "tf_pyclif_proto_library", ) load( "//tensorflow/core/platform:rules_cc.bzl", @@ -184,18 +178,18 @@ package_group(name = "friends") # filegroup; e.g. ones with individual proto_library targets. # LINT.IfChange COMMON_PROTO_SRCS = [ - "protobuf/bfc_memory_map.proto", - "protobuf/config.proto", - "protobuf/cluster.proto", - "protobuf/debug.proto", - "protobuf/device_filters.proto", - "protobuf/device_properties.proto", - "protobuf/graph_debug_info.proto", - "protobuf/queue_runner.proto", - "protobuf/rewriter_config.proto", - "protobuf/tensor_bundle.proto", - "protobuf/saver.proto", - "protobuf/verifier_config.proto", + "//tensorflow/core/protobuf:bfc_memory_map.proto", + "//tensorflow/core/protobuf:config.proto", + "//tensorflow/core/protobuf:cluster.proto", + "//tensorflow/core/protobuf:debug.proto", + "//tensorflow/core/protobuf:device_filters.proto", + "//tensorflow/core/protobuf:device_properties.proto", + "//tensorflow/core/protobuf:graph_debug_info.proto", + "//tensorflow/core/protobuf:queue_runner.proto", + "//tensorflow/core/protobuf:rewriter_config.proto", + "//tensorflow/core/protobuf:tensor_bundle.proto", + "//tensorflow/core/protobuf:saver.proto", + "//tensorflow/core/protobuf:verifier_config.proto", ] EXAMPLE_PROTO_SRCS = [ @@ -242,7 +236,7 @@ PROFILER_PROTO_SRCS = [ ] ERROR_CODES_PROTO_SRCS = [ - "protobuf/error_codes.proto", + "//tensorflow/core/protobuf:error_codes.proto", "//tensorflow/core/lib/core:error_codes.proto", ] # LINT.ThenChange(//tensorflow/core/portable_proto_config.asciipb) @@ -255,11 +249,13 @@ tf_proto_library( cc_api_version = 2, make_default_target_header_only = True, protodeps = [ - ":core_protos", - ":error_codes_proto_impl", "//tensorflow/core/example:protos_all", "//tensorflow/core/framework:protos_all", "//tensorflow/core/lib/core:error_codes_proto", + "//tensorflow/core/profiler/protobuf:xplane_proto", + "//tensorflow/core/profiler:profiler_options_proto", + "//tensorflow/core/protobuf:error_codes_proto_impl", + "//tensorflow/core/protobuf:for_core_protos", "//tensorflow/core/util:protos_all", "//tensorflow/core/util:test_log_proto_impl", ], @@ -1274,7 +1270,7 @@ filegroup( "//tensorflow/core/platform:mobile_srcs_no_runtime", "//tensorflow/core/public:mobile_srcs_no_runtime", "//tensorflow/core/util:mobile_srcs_no_runtime", - "//tensorflow/core/util/ctc:android_srcs", + 
"//tensorflow/core/util/ctc:mobile_srcs", ] + glob( [ "client/**/*.cc", @@ -1304,12 +1300,12 @@ filegroup( "//tensorflow/core/common_runtime/eager:srcs", "//tensorflow/core/framework:mobile_srcs_only_runtime", "//tensorflow/core/graph:mobile_srcs_only_runtime", - "//tensorflow/core/kernels:android_srcs", + "//tensorflow/core/kernels:mobile_srcs", "//tensorflow/core/lib/io:mobile_srcs_only_runtime", "//tensorflow/core/profiler:mobile_srcs", "//tensorflow/core/public:mobile_srcs_only_runtime", "//tensorflow/core/util/sparse:mobile_srcs_only_runtime", - "//tensorflow/core/util/tensor_bundle:android_srcs", + "//tensorflow/core/util/tensor_bundle:mobile_srcs", "//tensorflow/core/util:mobile_srcs_only_runtime", # Sources for which we already have granular targets. @@ -1382,10 +1378,9 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ - ":protos_all_cc_impl", "//tensorflow/core/util:stats_calculator_portable", "//tensorflow/core:mobile_additional_lib_deps", - ] + tf_portable_deps_no_runtime(), + ] + tf_portable_proto_lib() + tf_portable_deps_no_runtime(), alwayslink = 1, ) @@ -1417,54 +1412,12 @@ cc_library( ], ) -# Native library support for iOS applications. -# -# bazel build --config=ios_x86_64 \ -# :ios_tensorflow_lib -cc_library( - name = "ios_tensorflow_lib", - srcs = if_ios([ - ":portable_op_registrations_and_gradients", - "//tensorflow/core/kernels:android_core_ops", - "//tensorflow/core/kernels:android_extended_ops", - ]), - copts = tf_copts() + tf_opts_nortti_if_lite_protos() + ["-Os"], - visibility = ["//visibility:public"], - deps = [ - ":portable_tensorflow_lib_lite", - ":protos_all_cc_impl", - "//third_party/eigen3", - "//third_party/fft2d:fft2d_headers", - "@com_google_protobuf//:protobuf", - "@fft2d", - "@gemmlowp", - ], - alwayslink = 1, -) - alias( name = "ios_tensorflow_lib_lite", actual = ":portable_tensorflow_lib_lite", visibility = ["//visibility:public"], ) -cc_library( - name = "ios_tensorflow_test_lib", - testonly = 1, - srcs = if_ios([":android_test_srcs"]), - copts = tf_copts() + ["-Os"], - tags = [ - "manual", - "notap", - ], - visibility = ["//visibility:public"], - deps = [ - ":ios_tensorflow_lib", - "//tensorflow/core/platform/default/build_config:gtest", - "//third_party/eigen3", - ], -) - # Full TensorFlow library with operator support. Use this unless reducing # binary size (by packaging a reduced operator set) is a concern. 
alias( @@ -1473,10 +1426,16 @@ alias( visibility = ["//visibility:public"], ) +alias( + name = "ios_tensorflow_lib", + actual = ":portable_tensorflow_lib", + visibility = ["//visibility:public"], +) + cc_library( name = "portable_tensorflow_lib", srcs = if_mobile([":portable_op_registrations_and_gradients"]), - copts = tf_copts() + tf_opts_nortti_if_lite_protos(), + copts = tf_copts() + tf_opts_nortti_if_lite_protos() + if_ios(["-Os"]), features = tf_features_nomodules_if_mobile(), tags = [ "manual", @@ -1559,6 +1518,12 @@ alias( visibility = ["//visibility:public"], ) +alias( + name = "ios_tensorflow_test_lib", + actual = ":portable_tensorflow_test_lib", + visibility = ["//visibility:public"], +) + cc_library( name = "portable_tensorflow_test_lib", testonly = 1, @@ -1569,7 +1534,7 @@ cc_library( "//tensorflow/core/framework:android_test_hdrs", "//tensorflow/core/util:android_test_hdrs", ], - copts = tf_copts(android_optimization_level_override = None), + copts = tf_copts(android_optimization_level_override = None) + if_ios(["-Os"]), features = tf_features_nomodules_if_mobile() + tf_opts_nortti_if_lite_protos(), tags = [ "manual", @@ -1637,20 +1602,13 @@ alias( [ alias( name = "protobuf_%s_pyclif%s" % (proto_name, target_suffix), - actual = ":protobuf/%s_pyclif%s" % (proto_name, target_suffix), + actual = "//tensorflow/core/protobuf:%s_pyclif%s" % (proto_name, target_suffix), visibility = ["//visibility:public"], ) for target_suffix in [ "", "_pb2", ] - ] + [ - tf_pyclif_proto_library( - name = "protobuf/%s_pyclif" % proto_name, - proto_lib = ":protos_all", - proto_srcfile = "protobuf/%s.proto" % proto_name, - visibility = ["//visibility:public"], - ), ] for proto_name in [ "config", @@ -1664,77 +1622,74 @@ alias( # ----------------------------------------------------------------------------- # Internal targets -tf_proto_library( +alias( name = "autotuning_proto", - srcs = ["protobuf/autotuning.proto"], - cc_api_version = 2, - make_default_target_header_only = True, + actual = "//tensorflow/core/protobuf:autotuning_proto", visibility = [ "//tensorflow:internal", ], ) -tf_proto_library( +alias( + name = "autotuning_proto_cc", + actual = "//tensorflow/core/protobuf:autotuning_proto_cc", + visibility = [ + "//tensorflow:internal", + ], +) + +alias( name = "conv_autotuning_proto", - srcs = ["protobuf/conv_autotuning.proto"], - cc_api_version = 2, - make_default_target_header_only = True, - protodeps = [ - "//tensorflow/stream_executor:dnn_proto", - ], + actual = "//tensorflow/core/protobuf:conv_autotuning_proto", visibility = [ "//tensorflow:internal", ], ) -tf_proto_library_cc( - name = "worker_proto", - srcs = ["protobuf/worker.proto"], - cc_api_version = 2, - protodeps = tf_additional_all_protos(), - visibility = ["//visibility:public"], -) - -tf_proto_library_cc( - name = "worker_service_proto", - srcs = ["protobuf/worker_service.proto"], - has_services = 1, - cc_api_version = 2, - cc_stubby_versions = ["2"], - protodeps = [":worker_proto"], +alias( + name = "conv_autotuning_proto_cc", + actual = "//tensorflow/core/protobuf:conv_autotuning_proto_cc", visibility = [ "//tensorflow:internal", ], ) -tf_proto_library_cc( - name = "master_proto", - srcs = ["protobuf/master.proto"], - cc_api_version = 2, - protodeps = tf_additional_all_protos(), - visibility = ["//tensorflow:internal"], -) - -tf_proto_library_cc( - name = "master_service_proto", - srcs = ["protobuf/master_service.proto"], - has_services = 1, - cc_api_version = 2, - cc_stubby_versions = ["2"], - protodeps = [":master_proto"], 
+alias( + name = "worker_proto_cc", + actual = "//tensorflow/core/protobuf:worker_proto_cc", visibility = [ "//tensorflow:internal", ], ) -tf_proto_library_cc( - name = "eager_service_proto", - srcs = ["protobuf/eager_service.proto"], - has_services = 1, - cc_api_version = 2, - cc_grpc_version = 1, - cc_stubby_versions = ["2"], - protodeps = tf_additional_all_protos(), +alias( + name = "worker_service_proto_cc", + actual = "//tensorflow/core/protobuf:worker_service_proto_cc", + visibility = [ + "//tensorflow:internal", + ], +) + +alias( + name = "master_proto_cc", + actual = "//tensorflow/core/protobuf:master_proto_cc", + visibility = [ + "//learning/brain/frameworks/uptc:__subpackages__", + "//tensorflow:internal", + ], +) + +alias( + name = "master_service_proto_cc", + actual = "//tensorflow/core/protobuf:master_service_proto_cc", + visibility = [ + "//tensorflow:internal", + ], +) + +alias( + name = "eager_service_proto_cc", + actual = "//tensorflow/core/protobuf:eager_service_proto_cc", visibility = [ "//tensorflow:internal", ], @@ -2146,49 +2101,14 @@ cc_library( ], ) -tf_proto_library( +alias( name = "error_codes_proto_impl", - srcs = ["protobuf/error_codes.proto"], - cc_api_version = 2, - make_default_target_header_only = True, + actual = "//tensorflow/core/protobuf:error_codes_proto_impl", ) -tf_proto_library( - name = "core_protos", - srcs = COMMON_PROTO_SRCS + [ - # Protos which are not needed on mobile builds, but should be included - # in protos_all. - # - # Note that some protos are in neither core_proto_srcs nor this - # filegroup; e.g. ones with individual proto_library targets. - "protobuf/control_flow.proto", - # TODO(ebrevdo): Re-enable once CriticalSection is in core. - # "protobuf/critical_section.proto", - "protobuf/data/experimental/snapshot.proto", - "protobuf/debug_event.proto", - "protobuf/meta_graph.proto", - "protobuf/named_tensor.proto", - "protobuf/remote_tensor_handle.proto", - "protobuf/saved_model.proto", - "protobuf/saved_object_graph.proto", - "protobuf/struct.proto", - "protobuf/tensorflow_server.proto", - "protobuf/trackable_object_graph.proto", - "protobuf/transport_options.proto", - ], - cc_api_version = 2, - make_default_target_header_only = True, - protodeps = [ - ":error_codes_proto_impl", - "//tensorflow/core/example:protos_all", - "//tensorflow/core/framework:protos_all", - "//tensorflow/core/lib/core:error_codes_proto", - "//tensorflow/core/profiler/protobuf:xplane_proto", - "//tensorflow/core/profiler:profiler_options_proto", - "//tensorflow/core/util:protos_all", - "//tensorflow/core/util:test_log_proto_impl", - ], - visibility = ["//visibility:private"], +alias( + name = "error_codes_proto_impl_cc", + actual = "//tensorflow/core/protobuf:error_codes_proto_impl_cc", ) alias( @@ -2334,6 +2254,7 @@ tf_cuda_library( "//tensorflow/core/platform/default/build_config:platformlib", "//tensorflow/core/profiler/lib:annotated_traceme", "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/core/tpu:tpu_library_loader", "//tensorflow/core/util:einsum_op_util", "//tensorflow/core/util:padding", "//tensorflow/core/util:port", @@ -2480,13 +2401,9 @@ alias( visibility = ["//visibility:public"], ) -tf_proto_library_cc( - name = "replay_log_proto", - srcs = ["protobuf/replay_log.proto"], - cc_api_version = 2, - protodeps = [ - ":master_proto", - ] + tf_additional_all_protos(), +alias( + name = "replay_log_proto_cc", + actual = "//tensorflow/core/protobuf:replay_log_proto_cc", visibility = [ "//tensorflow:internal", ], @@ -3117,6 +3034,11 @@ alias( actual = 
"//tensorflow/core/platform:cuda_libdevice_path", ) +# Normalize CORE_PROTO_SRCS to generate valid output file names. +PORTABLE_PROTO_HEADERS_OUT = tf_android_core_proto_headers(CORE_PROTO_SRCS) + [ + "//google/protobuf/any.proto.h", +] + transitive_hdrs( name = "headers", visibility = ["//tensorflow:__subpackages__"], @@ -3129,8 +3051,3 @@ transitive_hdrs( "//tensorflow/core/platform:platform_strings", ], ) - -# Normalize CORE_PROTO_SRCS to generate valid output file names. -PORTABLE_PROTO_HEADERS_OUT = tf_android_core_proto_headers(CORE_PROTO_SRCS) + [ - "//google/protobuf/any.proto.h", -] diff --git a/tensorflow/core/api_def/base_api/api_def_ApplyFtrl.pbtxt b/tensorflow/core/api_def/base_api/api_def_ApplyFtrl.pbtxt index 0f49a18a114..f3379461a5f 100644 --- a/tensorflow/core/api_def/base_api/api_def_ApplyFtrl.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_ApplyFtrl.pbtxt @@ -65,7 +65,7 @@ END summary: "Update \'*var\' according to the Ftrl-proximal scheme." description: < l1 else 0.0 accum = accum_new diff --git a/tensorflow/core/api_def/base_api/api_def_ApplyFtrlV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_ApplyFtrlV2.pbtxt index 3218ab7776c..1eb33005e91 100644 --- a/tensorflow/core/api_def/base_api/api_def_ApplyFtrlV2.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_ApplyFtrlV2.pbtxt @@ -65,8 +65,8 @@ END summary: "Update \'*var\' according to the Ftrl-proximal scheme." description: < l1 else 0.0 diff --git a/tensorflow/core/api_def/base_api/api_def_BatchFunction.pbtxt b/tensorflow/core/api_def/base_api/api_def_BatchFunction.pbtxt index 09eff6177b1..ae5942b3617 100644 --- a/tensorflow/core/api_def/base_api/api_def_BatchFunction.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_BatchFunction.pbtxt @@ -84,6 +84,13 @@ END name: "Tout" description: <