diff --git a/.bazelrc b/.bazelrc index bdfd910d431..224238d7c0b 100644 --- a/.bazelrc +++ b/.bazelrc @@ -163,6 +163,8 @@ build:cuda_clang --action_env TF_CUDA_CLANG=1 build:dbg --config=opt -c dbg # for now, disable arm_neon. see: https://github.com/tensorflow/tensorflow/issues/33360 build:dbg --cxxopt -DTF_LITE_DISABLE_X86_NEON +# AWS SDK must be compiled in release mode. see: https://github.com/tensorflow/tensorflow/issues/37498 +build:dbg --copt -DDEBUG_BUILD build:tensorrt --action_env TF_NEED_TENSORRT=1 @@ -356,9 +358,10 @@ build:rbe_linux --linkopt=-lm build:rbe_cpu_linux --config=rbe_linux build:rbe_cpu_linux --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain" build:rbe_cpu_linux --extra_toolchains="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:cc-toolchain-k8" -build:rbe_cpu_linux --extra_execution_platforms"=@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010" -build:rbe_cpu_linux --host_platform="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010" -build:rbe_cpu_linux --platforms="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010" +build:rbe_cpu_linux --extra_execution_platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform" +build:rbe_cpu_linux --extra_execution_platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform" +build:rbe_cpu_linux --host_platform="@ubuntu16.04-manylinux2010-py3_config_platform//:platform" +build:rbe_cpu_linux --platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform" build:rbe_linux_cuda_base --config=rbe_linux build:rbe_linux_cuda_base --repo_env=TF_NEED_TENSORRT=1 diff --git a/README.md b/README.md index 27032043e07..ba4597af14c 100644 --- a/README.md +++ b/README.md @@ -103,17 +103,17 @@ open-source software development: ### Official Builds -Build Type | Status | Artifacts ------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- -**Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/) -**Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/) -**Linux XLA** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA -**macOS** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/) -**Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/) -**Windows GPU** | 
[![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/) -**Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) -**Raspberry Pi 0 and 1** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py2.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py2.html) [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py2](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp27-none-linux_armv6l.whl) [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl) -**Raspberry Pi 2 and 3** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py2.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py2.html) [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py2](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp27-none-linux_armv7l.whl) [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl) +Build Type | Status | Artifacts +------------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- +**Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/) +**Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/) +**Linux XLA** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA +**macOS** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/) +**Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/) +**Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/) +**Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | 
[![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) +**Raspberry Pi 0 and 1** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl) +**Raspberry Pi 2 and 3** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl) ### Community Supported Builds diff --git a/RELEASE.md b/RELEASE.md index b5d088821e4..6c8921cf492 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,147 @@ +# Release 2.2.0 + +TensorFlow 2.2 discontinues support for Python 2, [previously announced](https://groups.google.com/a/tensorflow.org/d/msg/announce/gVwS5RC8mds/dCt1ka2XAAAJ) as following [Python 2's EOL on January 1, 2020](https://www.python.org/dev/peps/pep-0373/#update). + +Coinciding with this change, new releases of [TensorFlow's Docker images](https://hub.docker.com/r/tensorflow/tensorflow/) provide Python 3 exclusively. Because all images now use Python 3, Docker tags containing `-py3` will no longer be provided and existing `-py3` tags like `latest-py3` will not be updated. + +## Major Features and Improvements + +* Replaced the scalar type for string tensors from `std::string` to `tensorflow::tstring` which is now ABI stable. +* A new Profiler for TF 2 for CPU/GPU/TPU. It offers both device and host performance analysis, including input pipeline and TF Ops. Optimization advisory is provided whenever possible. Please see [this tutorial](https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras) and [guide](https://www.tensorflow.org/guide/profiler) for usage guidelines. +* Export C++ functions to Python using `pybind11` as opposed to `SWIG` as a part of our [deprecation of swig efforts](https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md). +* `tf.distribute`: + * Support added for global sync `BatchNormalization` by using the newly added `tf.keras.layers.experimental.SyncBatchNormalization` layer. This layer will sync `BatchNormalization` statistics every step across all replicas taking part in sync training. + * Performance improvements for GPU multi-worker distributed training using `tf.distribute.experimental.MultiWorkerMirroredStrategy` + * Update NVIDIA `NCCL` to `2.5.7-1` for better performance and performance tuning. Please see [nccl developer guide](https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/env.html) for more information on this. + * Support gradient `allreduce` in `float16`. See this [example](https://github.com/tensorflow/models/blob/master/official/staging/training/grad_utils.py) usage. + * Experimental support of [all reduce gradient packing](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/CollectiveHints) to allow overlapping gradient aggregation with backward path computation. + * Deprecated `experimental_run_v2` method for distribution strategies and renamed the method `run` as it is no longer experimental. + * Add CompositeTensor support for DistributedIterators. This should help prevent unnecessary function retracing and memory leaks. 
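+
+For illustration, a minimal sketch of the sync batch normalization layer described above under a multi-worker strategy (the model shape, optimizer, and loss are arbitrary choices for the example, not part of the change itself):
+
+```python
+import tensorflow as tf
+
+strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
+
+with strategy.scope():
+  # SyncBatchNormalization aggregates batch statistics across all replicas
+  # taking part in sync training, instead of using per-replica statistics.
+  model = tf.keras.Sequential([
+      tf.keras.layers.Dense(64, activation="relu"),
+      tf.keras.layers.experimental.SyncBatchNormalization(),
+      tf.keras.layers.Dense(10),
+  ])
+  model.compile(
+      optimizer="adam",
+      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
+```
+
+Under the same scope, per-replica computations now go through `strategy.run(...)`, the non-experimental name for the deprecated `experimental_run_v2`.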
+* `tf.keras`: + * `Model.fit` major improvements: + * You can now use custom training logic with `Model.fit` by overriding `Model.train_step`. + * Easily write state-of-the-art training loops without worrying about all of the features `Model.fit` handles for you (distribution strategies, callbacks, data formats, looping logic, etc) + * See the default [`Model.train_step`](https://github.com/tensorflow/tensorflow/blob/1381fc8e15e22402417b98e3881dfd409998daea/tensorflow/python/keras/engine/training.py#L540) for an example of what this function should look like. Same applies for validation and inference via `Model.test_step` and `Model.predict_step`. + * SavedModel uses its own `Model._saved_model_inputs_spec` attr now instead of + relying on `Model.inputs` and `Model.input_names`, which are no longer set for subclass Models. + This attr is set in eager, `tf.function`, and graph modes. This gets rid of the need for users to + manually call `Model._set_inputs` when using Custom Training Loops(CTLs). + * Dynamic shapes are supported for generators by calling the Model on the first batch we "peek" from the generator. + This used to happen implicitly in `Model._standardize_user_data`. Long-term, a solution where the + `DataAdapter` doesn't need to call the Model is probably preferable. + * The SavedModel format now supports all Keras built-in layers (including metrics, preprocessing layers, and stateful RNN layers) + * Update Keras batch normalization layer to use the running mean and average computation in the `fused_batch_norm`. You should see significant performance improvements when using `fused_batch_norm` in Eager mode. + +* `tf.lite`: + * Enable TFLite experimental new converter by default. +* XLA + * XLA now builds and works on windows. All prebuilt packages come with XLA available. + * XLA can be [enabled for a `tf.function`](https://www.tensorflow.org/xla#explicit_compilation_with_tffunction +) with “compile or throw exception” semantics on CPU and GPU. + +## Breaking Changes +* `tf.keras`: + * In `tf.keras.applications` the name of the "top" layer has been standardized to "predictions". This is only a problem if your code relies on the exact name of the layer. + * Huber loss function has been updated to be consistent with other Keras losses. It now computes mean over the last axis of per-sample losses before applying the reduction function. +* AutoGraph no longer converts functions passed to `tf.py_function`, `tf.py_func` and `tf.numpy_function`. +* Deprecating `XLA_CPU` and `XLA_GPU` devices with this release. +* Increasing the minimum bazel version to build TF to 2.0.0 to use Bazel's `cc_experimental_shared_library`. +* Keras compile/fit behavior for functional and subclassed models have been unified. Model properties such as `metrics`, `metrics_names` will now be available only after **training/evaluating the model on actual data** for functional models. `metrics` will **now include** model `loss` and output losses.`loss_functions` property has been removed from the model. This was an undocumented property that was accidentally public and has now been removed. + +## Known Caveats +* The current TensorFlow release now **requires** [gast](https://pypi.org/project/gast/) version 0.3.3. + +## Bug Fixes and Other Changes +* `tf.data`: + * Removed `autotune_algorithm` from experimental optimization options. +* TF Core: + * `tf.constant` always creates CPU tensors irrespective of the current device context. 
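+
+As a minimal sketch of the `Model.train_step` override described under `tf.keras` above (the loss and metric handling below mirror the default behavior but are illustrative only):
+
+```python
+import tensorflow as tf
+
+class CustomModel(tf.keras.Model):
+  def train_step(self, data):
+    x, y = data
+    with tf.GradientTape() as tape:
+      y_pred = self(x, training=True)
+      # compiled_loss applies the loss configured in compile(), plus any
+      # regularization losses attached to the model.
+      loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
+    gradients = tape.gradient(loss, self.trainable_variables)
+    self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
+    self.compiled_metrics.update_state(y, y_pred)
+    return {m.name: m.result() for m in self.metrics}
+```
+
+A model defined this way is still driven by the usual `model.compile(...)` and `model.fit(dataset)` calls; only the per-batch logic changes.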
+    * Eager `TensorHandles` maintain a list of mirrors for any copies to local or remote devices. This avoids any redundant copies due to op execution.
+    * For `tf.Tensor` & `tf.Variable`, `.experimental_ref()` is no longer experimental and is available as simply `.ref()`.
+    * `pfor/vectorized_map`: Added support for vectorizing 56 more ops. Vectorizing `tf.cond` is also supported now.
+    * Set as much partial shape as we can infer statically within the gradient impl of the gather op.
+    * Gradient of `tf.while_loop` emits `StatelessWhile` op if `cond` and body functions are stateless. This allows multiple gradient while ops to run in parallel under distribution strategy.
+    * Speed up `GradientTape` in eager mode by auto-generating list of op inputs/outputs which are unused and hence not cached for gradient functions.
+    * Support `back_prop=False` in `while_v2` but mark it as deprecated.
+    * Improve error message when attempting to use `None` in data-dependent control flow.
+    * Add `RaggedTensor.numpy()`.
+    * Update `RaggedTensor.__getitem__` to preserve uniform dimensions & allow indexing into uniform dimensions.
+    * Update `tf.expand_dims` to always insert the new dimension as a non-ragged dimension.
+    * Update `tf.embedding_lookup` to use `partition_strategy` and `max_norm` when `ids` is ragged.
+    * Allow `batch_dims==rank(indices)` in `tf.gather`.
+    * Add support for bfloat16 in `tf.print`.
+* `tf.distribute`:
+    * Support `embedding_column` with variable-length input features for `MultiWorkerMirroredStrategy`.
+* `tf.keras`:
+    * Added `experimental_aggregate_gradients` argument to `tf.keras.optimizers.Optimizer.apply_gradients`. This allows custom gradient aggregation and processing of aggregated gradients in a custom training loop.
+    * Allow `pathlib.Path` paths for loading models via the Keras API.
+* `tf.function`/AutoGraph:
+    * AutoGraph is now available in `ReplicaContext.merge_call`, `Strategy.extended.update` and `Strategy.extended.update_non_slot`.
+    * Experimental support for shape invariants has been enabled in `tf.function`. See the API docs for `tf.autograph.experimental.set_loop_options` for additional info.
+    * AutoGraph error messages now exclude frames corresponding to APIs internal to AutoGraph.
+    * Improve shape inference for `tf.function` input arguments to unlock more Grappler optimizations in TensorFlow 2.x.
+    * Improve automatic control dependency management of resources by allowing resource reads to occur in parallel and synchronizing only on writes.
+    * Fix execution order of multiple stateful calls to `experimental_run_v2` in `tf.function`.
+    * You can now iterate over `RaggedTensors` using a for loop inside `tf.function`.
+* `tf.lite`:
+    * Migrated the `tf.lite` C inference API out of experimental into lite/c.
+    * Add an option to disallow `NNAPI` CPU / partial acceleration on Android 10.
+    * TFLite Android AARs now include the C headers and APIs that are required to use TFLite from native code.
+    * Refactors the delegate and delegate kernel sources to allow usage in the linter.
+    * Limit delegated ops to actually supported ones if a device name is specified or `NNAPI` CPU Fallback is disabled.
+    * TFLite now supports the `tf.math.reciprocal1` op by lowering it to the `tf.div` op.
+    * TFLite's unpack op now supports boolean tensor inputs.
+    * Microcontroller and embedded code moved from experimental to the main TensorFlow Lite folder.
+    * Check for large TFLite tensors.
+    * Fix GPU delegate crash with C++17.
+    * Add 5D support to TFLite `strided_slice`.
+    * Fix error in delegation of `DEPTH_TO_SPACE` to `NNAPI` causing the op not to be accelerated.
+    * Fix segmentation fault when running a model with LSTM nodes using `NNAPI` Delegate.
+    * Fix `NNAPI` delegate failure when an operand for Maximum/Minimum operation is a scalar.
+    * Fix `NNAPI` delegate failure when Axis input for reduce operation is a scalar.
+    * Expose option to limit the number of partitions that will be delegated to `NNAPI`.
+    * If a target accelerator is specified, use its feature level to determine operations to delegate instead of the SDK version.
+* `tf.random`:
+    * Various random number generation improvements:
+        * Add a fast path for default `random_uniform`.
+        * `random_seed` documentation improvement.
+        * `RandomBinomial` broadcasts and appends the sample shape to the left rather than the right.
+    * Added `tf.random.stateless_binomial`, `tf.random.stateless_gamma`, and `tf.random.stateless_poisson`.
+    * `tf.random.stateless_uniform` now supports unbounded sampling of `int` types.
+* Math and Linear Algebra:
+    * Add `tf.linalg.LinearOperatorTridiag`.
+    * Add `LinearOperatorBlockLowerTriangular`.
+    * Add broadcasting support to `tf.linalg.triangular_solve` [#26204](https://github.com/tensorflow/tensorflow/issues/26204) and `tf.math.invert_permutation`.
+    * Add `tf.math.sobol_sample` op.
+    * Add `tf.math.xlog1py`.
+    * Add `tf.math.special.{dawsn,expi,fresnel_cos,fresnel_sin,spence}`.
+    * Add a Modified Discrete Cosine Transform (MDCT) and its inverse to `tf.signal`.
+* TPU Enhancements:
+    * Refactor `TpuClusterResolver` to move shared logic to a separate pip package.
+    * Support configuring TPU software version from cloud tpu client.
+    * Allowed TPU embedding weight decay factor to be multiplied by learning rate.
+* XLA Support:
+    * Add standalone XLA AOT runtime target + relevant .cc sources to pip package.
+    * Add check for memory alignment to MemoryAllocation::MemoryAllocation() on 32-bit ARM. This ensures a deterministic early exit instead of a hard-to-debug bus error later.
+    * `saved_model_cli aot_compile_cpu` allows you to compile saved models to XLA header+object files and include them in your C++ programs.
+    * Enable `Igamma`, `Igammac` for XLA.
+* Deterministic Op Functionality:
+    * XLA reduction emitter is deterministic when the environment variable `TF_DETERMINISTIC_OPS` is set to "true" or "1". This extends deterministic `tf.nn.bias_add` back-prop functionality (and therefore also deterministic back-prop of bias-addition in Keras layers) to include when XLA JIT compilation is enabled.
+    * Fix problem, when running on a CUDA GPU and when either environment variable `TF_DETERMINISTIC_OPS` or environment variable `TF_CUDNN_DETERMINISTIC` is set to "true" or "1", in which some layer configurations led to an exception with the message "No algorithm worked!"
+* Tracing and Debugging:
+    * Add source, destination name to `_send` traceme to allow easier debugging.
+    * Add traceme event to `fastpathexecute`.
+* Other:
+    * Fix an issue with `AUC.reset_states` for multi-label AUC [#35852](https://github.com/tensorflow/tensorflow/issues/35852).
+    * Fix the TF upgrade script to not delete files when there is a parsing error and the output mode is `in-place`.
+    * Move `tensorflow/core:framework/*_pyclif` rules to `tensorflow/core/framework:*_pyclif`.
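+
+As a small illustration of the stateless samplers listed under `tf.random` above (the shapes and distribution parameters here are arbitrary):
+
+```python
+import tensorflow as tf
+
+# Stateless RNG ops take an explicit [2]-element integer seed and return
+# the same values every time they are run with that seed.
+seed = tf.constant([7, 42], dtype=tf.int32)
+binomial = tf.random.stateless_binomial(shape=[5], seed=seed, counts=10., probs=0.3)
+gamma = tf.random.stateless_gamma(shape=[5], seed=seed, alpha=2.0)
+poisson = tf.random.stateless_poisson(shape=[5], seed=seed, lam=4.0)
+```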
+ +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +372046933, 8bitmp3, aaronhma, Abin Shahab, Aditya Patwardhan, Agoniii, Ahti Kitsik, Alan Yee, Albin Joy, Alex Hoffman, Alexander Grund, Alexandre E. Eichenberger, Amit Kumar Jaiswal, amoitra, Andrew Anderson, Angus-Luo, Anthony Barbier, Anton Kachatkou, Anuj Rawat, archis, Arpan-Dhatt, Arvind Sundararajan, Ashutosh Hathidara, autoih, Bairen Yi, Balint Cristian, Bas Aarts, BashirSbaiti, Basit Ayantunde, Ben Barsdell, Benjamin Gaillard, boron, Brett Koonce, Bryan Cutler, Christian Goll, Christian Sachs, Clayne Robison, comet, Daniel Falbel, Daria Zhuravleva, darsh8200, David Truby, Dayananda-V, deepakm, Denis Khalikov, Devansh Singh, Dheeraj R Reddy, Diederik Van Liere, Diego Caballero, Dominic Jack, dothinking, Douman, Drake Gens, Duncan Riach, Ehsan Toosi, ekuznetsov139, Elena Zhelezina, elzino, Ending2015a, Eric Schweitz, Erik Zettel, Ethan Saadia, Eugene Kuznetsov, Evgeniy Zheltonozhskiy, Ewout Ter Hoeven, exfalso, FAIJUL, Fangjun Kuang, Fei Hu, Frank Laub, Frederic Bastien, Fredrik Knutsson, frreiss, Frédéric Rechtenstein, fsx950223, Gaurav Singh, gbaned, George Grzegorz Pawelczak, George Sterpu, Gian Marco Iodice, Giorgio Arena, Hans Gaiser, Hans Pabst, Haoyu Wu, Harry Slatyer, hsahovic, Hugo, Hugo Sjöberg, IrinaM21, jacco, Jake Tae, Jean-Denis Lesage, Jean-Michel Gorius, Jeff Daily, Jens Elofsson, Jerry Shih, jerryyin, Jin Mingjian, Jinjing Zhou, JKIsaacLee, jojimonv, Jonathan Dekhtiar, Jose Ignacio Gomez, Joseph-Rance, Judd, Julian Gross, Kaixi Hou, Kaustubh Maske Patil, Keunwoo Choi, Kevin Hanselman, Khor Chean Wei, Kilaru Yasaswi Sri Chandra Gandhi, Koan-Sin Tan, Koki Ibukuro, Kristian Holsheimer, kurileo, Lakshay Tokas, Lee Netherton, leike666666, Leslie-Fang-Intel, Li, Guizi, LIUJIAN435, Lukas Geiger, Lyo Nguyen, madisetti, Maher Jendoubi, Mahmoud Abuzaina, Manuel Freiberger, Marcel Koester, Marco Jacopo Ferrarotti, Markus Franke, marload, Mbah-Javis, mbhuiyan, Meng Zhang, Michael Liao, MichaelKonobeev, Michal Tarnowski, Milan Straka, minoring, Mohamed Nour Abouelseoud, MoussaMM, Mrinal Jain, mrTsjolder, Måns Nilsson, Namrata Bhave, Nicholas Gao, Niels Ole Salscheider, nikochiko, Niranjan Hasabnis, Nishidha Panpaliya, nmostafa, Noah Trenaman, nuka137, Officium, Owen L - Sfe, Pallavi G, Paul Andrey, Peng Sun, Peng Wu, Phil Pearl, PhilipMay, pingsutw, Pooya Davoodi, PragmaTwice, pshiko, Qwerty71, R Gomathi, Rahul Huilgol, Richard Xiao, Rick Wierenga, Roberto Rosmaninho, ruchit2801, Rushabh Vasani, Sami, Sana Damani, Sarvesh Dubey, Sasan Jafarnejad, Sergii Khomenko, Shane Smiskol, Shaochen Shi, sharkdtu, Shawn Presser, ShengYang1, Shreyash Patodia, Shyam Sundar Dhanabalan, Siju Samuel, Somyajit Chakraborty Sam, Srihari Humbarwadi, srinivasan.narayanamoorthy, Srishti Yadav, Steph-En-M, Stephan Uphoff, Stephen Mugisha, SumanSudhir, Taehun Kim, Tamas Bela Feher, TengLu, Tetragramm, Thierry Herrmann, Tian Jin, tigertang, Tom Carchrae, Tom Forbes, Trent Lo, Victor Peng, vijayphoenix, Vincent Abriou, Vishal Bhola, Vishnuvardhan Janapati, vladbataev, VoVAllen, Wallyss Lima, Wen-Heng (Jack) Chung, wenxizhu, William D. 
Irons, William Zhang, Xiaoming (Jason) Cui, Xiaoquan Kong, Xinan Jiang, Yasir Modak, Yasuhiro Matsumoto, Yaxun (Sam) Liu, Yong Tang, Ytyt-Yt, yuan, Yuan Mingshuai, Yuan Tang, Yuki Ueda, Yusup, zhangshijin, zhuwenxi + # Release 2.0.1 ## Bug Fixes and Other Changes diff --git a/configure.py b/configure.py index ac9ed0c4d88..945c3036a8d 100644 --- a/configure.py +++ b/configure.py @@ -144,7 +144,7 @@ def write_to_bazelrc(line): def write_action_env_to_bazelrc(var_name, var): - write_to_bazelrc('build --action_env %s="%s"' % (var_name, str(var))) + write_to_bazelrc('build --action_env {}="{}"'.format(var_name, str(var))) def run_shell(cmd, allow_non_zero=False, stderr=None): @@ -205,7 +205,7 @@ def setup_python(environ_cp): # Get PYTHON_BIN_PATH, default is the current running python. default_python_bin_path = sys.executable ask_python_bin_path = ('Please specify the location of python. [Default is ' - '%s]: ') % default_python_bin_path + '{}]: ').format(default_python_bin_path) while True: python_bin_path = get_from_env_or_user_or_default(environ_cp, 'PYTHON_BIN_PATH', @@ -215,9 +215,10 @@ def setup_python(environ_cp): if os.path.isfile(python_bin_path) and os.access(python_bin_path, os.X_OK): break elif not os.path.exists(python_bin_path): - print('Invalid python path: %s cannot be found.' % python_bin_path) + print('Invalid python path: {} cannot be found.'.format(python_bin_path)) else: - print('%s is not executable. Is it the python binary?' % python_bin_path) + print('{} is not executable. Is it the python binary?'.format( + python_bin_path)) environ_cp['PYTHON_BIN_PATH'] = '' # Convert python path to Windows style before checking lib and version @@ -236,7 +237,7 @@ def setup_python(environ_cp): default_python_lib_path = python_lib_paths[0] python_lib_path = get_input( 'Please input the desired Python library path to use. ' - 'Default is [%s]\n' % python_lib_paths[0]) + 'Default is [{}]\n'.format(python_lib_paths[0])) if not python_lib_path: python_lib_path = default_python_lib_path environ_cp['PYTHON_LIB_PATH'] = python_lib_path @@ -252,7 +253,7 @@ def setup_python(environ_cp): # Set-up env variables used by python_configure.bzl write_action_env_to_bazelrc('PYTHON_BIN_PATH', python_bin_path) write_action_env_to_bazelrc('PYTHON_LIB_PATH', python_lib_path) - write_to_bazelrc('build --python_path=\"%s"' % python_bin_path) + write_to_bazelrc('build --python_path=\"{}"'.format(python_bin_path)) environ_cp['PYTHON_BIN_PATH'] = python_bin_path # If choosen python_lib_path is from a path specified in the PYTHONPATH @@ -266,7 +267,7 @@ def setup_python(environ_cp): with open( os.path.join(_TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'), 'w') as f: - f.write('export PYTHON_BIN_PATH="%s"' % python_bin_path) + f.write('export PYTHON_BIN_PATH="{}"'.format(python_bin_path)) def reset_tf_configure_bazelrc(): @@ -320,11 +321,12 @@ def get_var(environ_cp, Raise the error to avoid infinitely looping. """ if not question: - question = 'Do you wish to build TensorFlow with %s support?' % query_item + question = 'Do you wish to build TensorFlow with {} support?'.format( + query_item) if not yes_reply: - yes_reply = '%s support will be enabled for TensorFlow.' 
% query_item + yes_reply = '{} support will be enabled for TensorFlow.'.format(query_item) if not no_reply: - no_reply = 'No %s' % yes_reply + no_reply = 'No {}'.format(yes_reply) yes_reply += '\n' no_reply += '\n' @@ -368,7 +370,7 @@ def get_var(environ_cp, print(no_reply) var = False else: - print('Invalid selection: %s' % user_input_origin) + print('Invalid selection: {}'.format(user_input_origin)) return var @@ -479,13 +481,13 @@ def check_bazel_version(min_version, max_version): if which('bazel') is None: print('Cannot find bazel. Please install bazel.') sys.exit(0) - curr_version = run_shell( - ['bazel', '--batch', '--bazelrc=/dev/null', 'version']) - for line in curr_version.split('\n'): - if 'Build label: ' in line: - curr_version = line.split('Build label: ')[1] - break + stderr = open(os.devnull, 'wb') + curr_version = run_shell(['bazel', '--version'], + allow_non_zero = True, + stderr = stderr) + if curr_version.startswith('bazel '): + curr_version = curr_version.split('bazel ')[1] min_version_int = convert_version_to_int(min_version) curr_version_int = convert_version_to_int(curr_version) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index f2018220a56..ab4316d5ed0 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -517,6 +517,7 @@ package_group( "//perftools/accelerators/xprof/api/...", "//third_party/py/autograph/...", "//third_party/swift/tensorflow/x10/...", + "//third_party/swift/tensorflow_apis/...", "//tensorflow/...", "//tensorflow_estimator/python/estimator/...", "//tensorflow_models/official/...", @@ -529,6 +530,13 @@ package_group(name = "ndarray_tensor_allow_list") # TODO(b/154762408) Remove this package group once it's no longer needed. package_group(name = "composite_tensor_whitelist") +# Packages that use private types symbols, until they are exported. +# TODO(b/154650521) Remove. 
+package_group( + name = "types_whitelist", + packages = ["//learning/deepmind/tensorflow/replicator/..."], +) + filegroup( name = "intel_binary_blob", data = if_mkl_ml( diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 1c4c0d1e06a..05d5f9a3ed2 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -58,6 +58,7 @@ filegroup( name = "pywrap_required_hdrs", srcs = [ "c_api_internal.h", + "conversion_macros.h", "python_api.h", "tensor_interface.h", "tf_status_helper.h", @@ -84,7 +85,14 @@ tf_cuda_library( ], deps = select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", + ], + "//tensorflow:chromiumos": [ + ":tf_attrtype", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/platform:platform", ], "//conditions:default": [ ":tf_attrtype", @@ -174,7 +182,7 @@ tf_cuda_library( ":tf_status_internal", ] + select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", ], "//conditions:default": [ ":tf_status", @@ -211,7 +219,7 @@ tf_cuda_library( ], deps = select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs ], "//conditions:default": [ "//tensorflow/core:lib", @@ -224,12 +232,13 @@ cc_library( srcs = ["tf_status.cc"], hdrs = ["tf_status.h"], visibility = ["//visibility:public"], - deps = select({ + deps = [ + ":tf_status_internal", + ] + select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs ], "//conditions:default": [ - ":tf_status_internal", "//tensorflow/core:lib", ], }), @@ -251,10 +260,15 @@ cc_library( name = "tensor_interface", hdrs = ["tensor_interface.h"], visibility = ["//tensorflow:internal"], - deps = [ - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - ], + deps = select({ + "//tensorflow:android": [ + "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs + ], + "//conditions:default": [ + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], + }), ) cc_library( @@ -264,7 +278,7 @@ cc_library( visibility = ["//visibility:public"], deps = select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs + "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs ], "//conditions:default": [ "//tensorflow/core:framework", @@ -278,16 +292,17 @@ cc_library( srcs = ["tf_tensor.cc"], hdrs = ["tf_tensor.h"], visibility = ["//visibility:public"], - deps = select({ + deps = [ + ":tensor_interface", + ":tf_datatype", + ":tf_status", + ":tf_status_helper", + ":tf_tensor_internal", + ] + select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs ], "//conditions:default": [ - ":tensor_interface", - ":tf_datatype", - ":tf_status", - ":tf_status_helper", - ":tf_tensor_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -303,14 +318,15 @@ tf_cuda_library( "tf_tensor_internal.h", ], visibility = ["//tensorflow:internal"], - deps = select({ + deps = [ + ":tensor_interface", + ":tf_datatype", + ":tf_status", + ] + 
select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs ], "//conditions:default": [ - ":tensor_interface", - ":tf_datatype", - ":tf_status", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:casts", @@ -418,7 +434,7 @@ tf_cuda_library( visibility = ["//visibility:public"], deps = select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", ], "//conditions:default": [ "//tensorflow/core:framework", @@ -449,7 +465,7 @@ tf_cuda_library( ] + select({ "//tensorflow:android": [ ":c_api_internal", - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", ], "//conditions:default": [ ":c_api_internal", @@ -476,7 +492,7 @@ tf_cuda_library( ":tf_status_helper", ] + select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", ], "//conditions:default": [ "//tensorflow/core:framework", @@ -532,6 +548,7 @@ tf_cuda_cc_test( "//conditions:default": [], }), tags = [ + "no_windows", # TODO(b/155444728) "noasan", ], # We must ensure that the dependencies can be dynamically linked since diff --git a/tensorflow/c/conversion_macros.h b/tensorflow/c/conversion_macros.h index ce8adfadb26..d1f99b7b5b0 100644 --- a/tensorflow/c/conversion_macros.h +++ b/tensorflow/c/conversion_macros.h @@ -16,15 +16,18 @@ limitations under the License. #ifndef TENSORFLOW_C_CONVERSION_MACROS_H_ #define TENSORFLOW_C_CONVERSION_MACROS_H_ -#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper) \ - inline cpp_impl *unwrap(wrapper *w) { \ - return reinterpret_cast(w); \ - } \ - \ - inline const cpp_impl *unwrap(const wrapper *w) { \ - return reinterpret_cast(w); \ - } \ - \ - inline wrapper *wrap(cpp_impl *i) { return reinterpret_cast(i); } +#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper) \ + inline cpp_impl *unwrap(wrapper *w) { \ + return reinterpret_cast(w); \ + } \ + \ + inline const cpp_impl *unwrap(const wrapper *w) { \ + return reinterpret_cast(w); \ + } \ + \ + inline wrapper *wrap(cpp_impl *i) { return reinterpret_cast(i); } \ + inline const wrapper *wrap(const cpp_impl *i) { \ + return reinterpret_cast(i); \ + } #endif // TENSORFLOW_C_CONVERSION_MACROS_H_ diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 42a31444380..fe4d5ac6ffe 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -35,7 +35,7 @@ tf_cuda_library( visibility = ["//visibility:public"], deps = select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", ], "//conditions:default": [ ":context_interface", @@ -246,6 +246,7 @@ cc_library( "//tensorflow:internal", ], deps = [ + "//tensorflow/c:conversion_macros", "//tensorflow/c:tf_status", "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime/eager:attr_builder", @@ -316,7 +317,8 @@ tf_cuda_cc_test( ], extra_copts = tfe_xla_copts(), tags = [ - "guitar", + "noguitar", # TODO(b/155445984): flaky + #"guitar", "multi_gpu", ], deps = [ @@ -344,7 +346,10 @@ tf_cuda_cc_test( srcs = [ "c_api_remote_test.cc", ], + # TODO(b/136478427): Figure out how to correctly shut the server down + args = ["--heap_check=local"], extra_copts = tfe_xla_copts(), + tags = ["noasan"], # leaks gRPC server instances deps = [ ":c_api", 
":c_api_experimental", @@ -362,6 +367,34 @@ tf_cuda_cc_test( ], ) +tf_cuda_cc_test( + name = "c_api_cluster_test", + size = "small", + srcs = [ + "c_api_cluster_test.cc", + ], + # TODO(b/136478427): Figure out how to correctly shut the server down + args = ["--heap_check=local"], + extra_copts = tfe_xla_copts(), + tags = ["noasan"], # leaks gRPC server instances + deps = [ + ":c_api", + ":c_api_experimental", + ":c_api_internal", + ":c_api_test_util", + ":tfe_tensorhandle_internal", + "//tensorflow/c:c_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/common_runtime/eager:eager_operation", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", + "//tensorflow/core/platform:env", + "@com_google_absl//absl/strings", + ], +) + tf_cuda_library( name = "c_api_experimental", srcs = [ @@ -379,7 +412,7 @@ tf_cuda_library( visibility = ["//visibility:public"], deps = select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:portable_tensorflow_lib_lite", ], "//conditions:default": [ ":c_api", @@ -415,6 +448,8 @@ tf_cuda_library( "//conditions:default": [], }) + [ "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/container:flat_hash_map", "//tensorflow/c:tf_status_helper", "//tensorflow/core/distributed_runtime/eager:eager_client", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", @@ -575,7 +610,6 @@ filegroup( ], exclude = [ "c_api_experimental.cc", - "*c_api_tfrt*", "*test*", "*dlpack*", ], diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 540efe9dcc0..5c01ccb82bb 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -38,7 +38,7 @@ limitations under the License. #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/c/tf_tensor_internal.h" #ifdef PLATFORM_GOOGLE -#include "tensorflow/c/eager/c_api_tfrt.h" +#include "tensorflow/core/tfrt/eager/c_api_tfrt.h" #endif #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/eager/context.h" @@ -500,6 +500,17 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( grpc_server->master_env()->worker_cache->GetEagerClientCache( &remote_eager_workers)); + // For cluster update, use a status group to aggregate statuses from + // * adding and removing remote devices + // * creating remote contexts on newly added workers + // * updating remote contexts on existing workers + // * updating the master context + // Note that we should not return immediately on errors in the middle of these + // updates to prevent cluster from having inconsistent context views. + // + // Unused if `reset_context` is True. 
+ tensorflow::StatusGroup sg; + // When updating an existing context, populate the following lists with: // * added_workers: set(remote_workers) - set(curr_remote_workers) // * removed_workers: set(curr_remote_workers) - set(remote_workers) @@ -535,7 +546,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( DifferentiateWorkerLists(&curr_remote_workers, &remote_workers, &added_workers, &removed_workers, &existing_workers); - LOG_AND_RETURN_IF_ERROR(GetReplacedFromExistingWorkers( + sg.Update(GetReplacedFromExistingWorkers( &existing_workers, context_id, context->GetContextViewId(), server_def, remote_eager_workers.get(), &replaced_workers)); if (VLOG_IS_ON(1)) { @@ -559,11 +570,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( existing_workers.end()); } } - LOG_AND_RETURN_IF_ERROR( - RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr)); - LOG_AND_RETURN_IF_ERROR(AddRemoteDevicesToMgr( - added_workers, grpc_server->master_env()->worker_cache, - remote_device_mgr)); + sg.Update(RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr)); + sg.Update(AddRemoteDevicesToMgr(added_workers, + grpc_server->master_env()->worker_cache, + remote_device_mgr)); } std::vector cluster_device_attributes; @@ -584,7 +594,6 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( } // Initialize remote eager workers. - // TODO(b/138847548) Create remote eager contexts in async mode by default. if (reset_context) { LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts( ctx, remote_workers, context_id, context_view_id, keep_alive_secs, @@ -596,7 +605,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( // existing workers to also have the updated context_view_id, so // we must set their context_view_id to the existing master's // context_view_id + 1. - LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts( + sg.Update(CreateRemoteContexts( ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs, server_def, remote_eager_workers.get(), context->Executor().Async(), context->LazyCopyFunctionRemoteInputs(), base_request)); @@ -606,10 +615,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( VLOG(1) << "Updating cluster with existing worker " << w; } } - LOG_AND_RETURN_IF_ERROR(UpdateRemoteContexts( - ctx, existing_workers, added_workers, removed_workers, context_id, - context_view_id + 1, server_def, remote_eager_workers.get(), - base_request)); + sg.Update(UpdateRemoteContexts(ctx, existing_workers, added_workers, + removed_workers, context_id, + context_view_id + 1, server_def, + remote_eager_workers.get(), base_request)); } } @@ -645,13 +654,13 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( // GrpcServer cannot be destroyed after it is started. 
LOG_AND_RETURN_IF_ERROR(grpc_server->Start()); } else { - LOG_AND_RETURN_IF_ERROR( - grpc_server->worker_env()->session_mgr->UpdateSession( - session_name, server_def, base_request.cluster_device_attributes(), - /*isolate_session_state=*/true)); - LOG_AND_RETURN_IF_ERROR( - context->UpdateRemoteMaster(context_id, std::move(remote_eager_workers), - added_workers, removed_workers)); + sg.Update(grpc_server->worker_env()->session_mgr->UpdateSession( + session_name, server_def, base_request.cluster_device_attributes(), + /*isolate_session_state=*/true)); + sg.Update(context->UpdateRemoteMaster(context_id, + std::move(remote_eager_workers), + added_workers, removed_workers)); + LOG_AND_RETURN_IF_ERROR(sg.as_summary_status()); } #undef LOG_AND_RETURN_IF_ERROR @@ -685,8 +694,13 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { if (opts->use_tfrt) { #ifdef PLATFORM_GOOGLE - status->status = tensorflow::Status::OK(); - return tensorflow::wrap(new tfrt::ContextInterface()); + tfrt::SmallVector op_handler_chains; + tfrt::SmallVector device_attributes; + status->status = tfrt::ListOpHandlerChains( + opts->session_options.options, &op_handler_chains, &device_attributes); + if (!status->status.ok()) return nullptr; + return tensorflow::wrap( + new tfrt::ContextInterface(op_handler_chains, device_attributes)); #else status->status = tensorflow::errors::Unimplemented("TFRT is not supported"); return nullptr; @@ -910,7 +924,7 @@ extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( context->GetDevicePlacementPolicy()); } -TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { +TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) { tensorflow::Tensor tensor; status->status = tensorflow::TF_TensorToTensor(t, &tensor); if (!status->status.ok()) return nullptr; @@ -1458,20 +1472,21 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func, } // namespace void TFE_ContextStartStep(TFE_Context* ctx) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->StartStep(); + tensorflow::unwrap(ctx)->StartStep(); } void TFE_ContextEndStep(TFE_Context* ctx) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->EndStep(); + tensorflow::unwrap(ctx)->EndStep(); +} + +const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op) { + return tensorflow::wrap( + &OperationFromInterface(tensorflow::unwrap(op))->Attrs()); } void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) { tensorflow::AttrValueMap m; - attrs->attributes->FillAttrValueMap(&m); + tensorflow::unwrap(attrs)->FillAttrValueMap(&m); tensorflow::EagerOperation* operation = OperationFromInterface(tensorflow::unwrap(op)); tensorflow::AttrBuilder* destination = operation->MutableAttrs(); @@ -1483,8 +1498,8 @@ void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) { void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf, TF_Status* status) { tensorflow::NameAttrList name_and_attrs; - attrs->attributes->FillAttrValueMap(name_and_attrs.mutable_attr()); - name_and_attrs.set_name(attrs->attributes->op_name()); + tensorflow::unwrap(attrs)->FillAttrValueMap(name_and_attrs.mutable_attr()); + name_and_attrs.set_name(tensorflow::unwrap(attrs)->op_name()); status->status = MessageToBuffer(name_and_attrs, buf); } @@ -1605,9 +1620,9 @@ class CustomDeviceAPI : public 
tensorflow::CustomDevice { } std::vector outputs(*num_retvals); TF_Status status; - TFE_OpAttrs attributes(&op->Attrs()); device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(), - &attributes, num_retvals, outputs.data(), &status, info_); + wrap(&op->Attrs()), num_retvals, outputs.data(), &status, + info_); if (status.status.ok()) { for (int i = 0; i < *num_retvals; ++i) { retvals[i] = tensorflow::TensorHandleFromInterface( diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 070b3a9bb60..5afe3047dd7 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -137,7 +137,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, // placed in memory of different devices or remote address spaces. typedef struct TFE_TensorHandle TFE_TensorHandle; -TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status); // Indicates that the caller will not be using `h` any more. TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h); diff --git a/tensorflow/c/eager/c_api_cluster_test.cc b/tensorflow/c/eager/c_api_cluster_test.cc new file mode 100644 index 00000000000..252a0408758 --- /dev/null +++ b/tensorflow/c/eager/c_api_cluster_test.cc @@ -0,0 +1,433 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" +#include "tensorflow/core/common_runtime/eager/eager_operation.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" +#include "tensorflow/core/platform/casts.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/cluster.pb.h" +#include "tensorflow/core/protobuf/tensorflow_server.pb.h" + +namespace { + +using ::tensorflow::string; + +tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) { + tensorflow::ServerDef server_def; + server_def.set_protocol("grpc"); + server_def.set_job_name(job_name); + server_def.set_task_index(0); + tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster(); + tensorflow::JobDef* job_def = cluster_def->add_job(); + job_def->set_name(job_name); + for (int i = 0; i < num_tasks; i++) { + int port = tensorflow::testing::PickUnusedPortOrDie(); + job_def->mutable_tasks()->insert( + {i, tensorflow::strings::StrCat("localhost:", port)}); + } + return server_def; +} + +tensorflow::ServerDef GetServerDef(int num_tasks) { + return GetServerDef("localhost", num_tasks); +} + +void ReplaceTaskInServerDef(tensorflow::ServerDef* server_def, int task_index) { + tensorflow::JobDef* job_def = server_def->mutable_cluster()->mutable_job(0); + int port = tensorflow::testing::PickUnusedPortOrDie(); + job_def->mutable_tasks()->at(task_index) = + tensorflow::strings::StrCat("localhost:", port); +} + +void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle, + const std::vector& expected_values) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_Tensor* t = TFE_TensorHandleResolve(handle, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + std::unique_ptr actual_values(new float[expected_values.size()]); + EXPECT_EQ(sizeof(float) * expected_values.size(), TF_TensorByteSize(t)); + memcpy(actual_values.get(), TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + + for (int i = 0; i < expected_values.size(); i++) { + EXPECT_EQ(expected_values[i], actual_values[i]) + << "Mismatch in expected values at (zero-based) index " << i; + } +} + +void CheckRemoteMatMulExecutesOK(TFE_Context* ctx, + const char* remote_device_name, + const char* local_device_name) { + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx); + + TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0); + TFE_OpSetDevice(matmul, remote_device_name, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + auto* retval_task0 = + TFE_TensorHandleCopyToDevice(retvals[0], ctx, local_device_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + CheckTFE_TensorHandleHasFloats(retval_task0, {7, 10, 15, 22}); + + TFE_DeleteTensorHandle(retval_task0); + TFE_DeleteTensorHandle(h0_task0); + TFE_DeleteTensorHandle(retvals[0]); + + TFE_DeleteOp(matmul); + + TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorWaitForAllPendingNodes(executor, 
status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteExecutor(executor); + TF_DeleteStatus(status); +} + +// Read the value of variable `var` and save it into `out_value`. +void ReadVariable(TFE_Context* ctx, TFE_TensorHandle* var, + TFE_TensorHandle** out_value) { + TF_Status* status = TF_NewStatus(); + TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpSetAttrType(op, "dtype", TF_FLOAT); + TFE_OpAddInput(op, var, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + int num_retvals = 1; + TFE_Execute(op, out_value, &num_retvals, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteOp(op); + TF_DeleteStatus(status); +} + +void TestRemoteExecuteChangeServerDef(bool async) { + tensorflow::ServerDef server_def = GetServerDef(2); + + // This server def has the task index set to 0. + string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + + std::unique_ptr worker_server; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server) + .ok()); + ASSERT_TRUE(worker_server->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + const char remote_device_name[] = + "/job:localhost/replica:0/task:1/device:CPU:0"; + const char local_device_name[] = + "/job:localhost/replica:0/task:0/device:CPU:0"; + CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name); + + TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorWaitForAllPendingNodes(executor, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // TODO(b/136478427): Figure out how to correctly shut the server down. + worker_server.release(); + + // Update the server def with a new set of names (worker instead of + // localhost). + tensorflow::ServerDef updated_server_def = GetServerDef("worker", 2); + serialized = updated_server_def.SerializeAsString(); + + updated_server_def.set_task_index(1); + tensorflow::Status s = tensorflow::GrpcServer::Create( + updated_server_def, tensorflow::Env::Default(), &worker_server); + ASSERT_TRUE(s.ok()) << s.error_message(); + ASSERT_TRUE(worker_server->Start().ok()); + + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // Create a new tensor_handle. + TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle(ctx); + + // Check that copying it to the old remote device (named localhost) fails. + TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status); + EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // Copying and executing on the new remote device works. 
+ const char new_remote_device_name[] = + "/job:worker/replica:0/task:1/device:CPU:0"; + const char new_local_device_name[] = + "/job:worker/replica:0/task:0/device:CPU:0"; + + auto* h0_task1_new = TFE_TensorHandleCopyToDevice( + h0_task0_new, ctx, new_remote_device_name, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_DeleteTensorHandle(h0_task0_new); + TFE_DeleteTensorHandle(h0_task1_new); + + CheckRemoteMatMulExecutesOK(ctx, new_remote_device_name, + new_local_device_name); + + TFE_ExecutorWaitForAllPendingNodes(executor, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteExecutor(executor); + + TF_DeleteStatus(status); + + TFE_DeleteContext(ctx); + + // TODO(b/136478427): Figure out how to correctly shut the server down. + worker_server.release(); +} + +TEST(CAPI, RemoteExecuteChangeServerDef) { + TestRemoteExecuteChangeServerDef(false); +} +TEST(CAPI, RemoteExecuteChangeServerDefAsync) { + TestRemoteExecuteChangeServerDef(true); +} + +void TestRemoteExecuteUpdateServerDef(bool async) { + tensorflow::ServerDef server_def = GetServerDef(2); + // This server def has the task index set to 0. + string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + std::unique_ptr worker_server; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server) + .ok()); + ASSERT_TRUE(worker_server->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + const char local_device_name[] = + "/job:localhost/replica:0/task:0/device:CPU:0"; + const char remote_device_name[] = + "/job:localhost/replica:0/task:1/device:CPU:0"; + CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name); + + TFE_ContextUpdateServerDef(ctx, 0, serialized.data(), serialized.size(), + status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name); + + TFE_DeleteContext(ctx); + TF_DeleteStatus(status); + + // TODO(b/136478427): Figure out how to correctly shut the server down. + worker_server.release(); +} + +TEST(CAPI, RemoteExecuteUpdateServerDef) { + TestRemoteExecuteUpdateServerDef(false); +} + +TEST(CAPI, RemoteExecuteUpdateServerDefAsync) { + TestRemoteExecuteUpdateServerDef(true); +} + +void TestRemoteExecuteUpdateServerDefResourceAccess(bool async) { + tensorflow::ServerDef server_def = GetServerDef(2); + // This server def has the task index set to 0. 
+ string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + std::unique_ptr worker_server; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server) + .ok()); + ASSERT_TRUE(worker_server->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + const char dev0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0"; + const char dev1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0"; + + TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name); + EXPECT_NE(var_handle0, nullptr); + TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name); + EXPECT_NE(var_handle1, nullptr); + + TFE_TensorHandle* value_handle = nullptr; + ReadVariable(ctx, var_handle1, &value_handle); + CheckTFE_TensorHandleHasFloats(value_handle, {2}); + TFE_DeleteTensorHandle(value_handle); + + // Start a new worker to replace task:1 + ReplaceTaskInServerDef(&server_def, 1); + server_def.set_task_index(1); + // TODO(b/136478427): Figure out how to correctly shut the server down. + worker_server.release(); + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server) + .ok()); + ASSERT_TRUE(worker_server->Start().ok()); + + // Update server def to replace the remote device with the device info on the + // new worker (different incarnation ID). + server_def.set_task_index(0); + string serialized_update = server_def.SerializeAsString(); + TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(), + serialized_update.size(), status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // The device of var_handle0 is local device which is the same before and + // after cluster update. Remove resource with valid device should succeed. + TFE_Op* op = TFE_NewOp(ctx, "DestroyResourceOp", status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, var_handle0, status); + TFE_OpSetDevice(op, dev0_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + int num_retvals = 0; + TFE_Execute(op, nullptr, &num_retvals, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteOp(op); + + // The device of var_handle1 is remote device, which was replaced during + // cluster update. Removing resource with invalid device should fail + // gracefully (i.e., with error status) instead of crashing with segfaults. + op = TFE_NewOp(ctx, "DestroyResourceOp", status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, var_handle1, status); + TFE_OpSetDevice(op, dev1_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + num_retvals = 0; + TFE_Execute(op, nullptr, &num_retvals, status); + EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteOp(op); + + TFE_DeleteTensorHandle(var_handle0); + TFE_DeleteTensorHandle(var_handle1); + + TFE_DeleteContext(ctx); + TF_DeleteStatus(status); + + // TODO(b/136478427): Figure out how to correctly shut the server down. 
+ worker_server.release(); +} + +TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccess) { + TestRemoteExecuteUpdateServerDefResourceAccess(false); +} + +TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccessAsync) { + TestRemoteExecuteUpdateServerDefResourceAccess(true); +} + +void TestRemoteExecuteUpdateServerDefWithFailures(bool async) { + // Fail fast on GetStatus requests so we can get errors instead of timeouts + // when updating the cluster with a non-existent worker. + tensorflow::setenv("GRPC_FAIL_FAST", "TRUE", /*overwrite=*/1); + + tensorflow::ServerDef server_def = GetServerDef(2); + // This server def has the task index set to 0. + string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + std::unique_ptr<tensorflow::GrpcServer> worker_server; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server) + .ok()); + ASSERT_TRUE(worker_server->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast<bool>(async)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + const char local_device_name[] = + "/job:localhost/replica:0/task:0/device:CPU:0"; + const char remote_device_name[] = + "/job:localhost/replica:0/task:1/device:CPU:0"; + CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name); + + // Adding a non-existent remote worker to cluster def. This should cause the + // UpdateServerDef call to fail. + tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster(); + tensorflow::JobDef* job_def = cluster_def->mutable_job(0); + int port = tensorflow::testing::PickUnusedPortOrDie(); + job_def->mutable_tasks()->insert( + {2, tensorflow::strings::StrCat("localhost:", port)}); + server_def.set_task_index(0); + string serialized_update = server_def.SerializeAsString(); + TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(), + serialized_update.size(), status); + EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // Even after the previously failed cluster update, another update and op + // execution should work fine as long as the provided server_def is valid. + TFE_ContextUpdateServerDef(ctx, 0, serialized.data(), serialized.size(), + status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name); + + TFE_DeleteContext(ctx); + TF_DeleteStatus(status); + + // TODO(b/136478427): Figure out how to correctly shut the server down. + worker_server.release(); + tensorflow::unsetenv("GRPC_FAIL_FAST"); +} + +TEST(CAPI, RemoteExecuteUpdateServerDefWithFailures) { + TestRemoteExecuteUpdateServerDefWithFailures(false); +} + +TEST(CAPI, RemoteExecuteUpdateServerDefWithFailuresAsync) { + TestRemoteExecuteUpdateServerDefWithFailures(true); +} + +} // namespace diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index 820650e315f..0d71b11531b 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -23,6 +23,7 @@ limitations under the License.
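A condensed sketch of the recovery behaviour the tests above exercise: a failed TFE_ContextUpdateServerDef leaves the eager context usable, so a follow-up update with a valid ServerDef still succeeds. This assumes an existing TFE_Context* ctx, a TF_Status* status, and two serialized tensorflow::ServerDef strings; bad_update and serialized are illustrative names only:

  TFE_ContextUpdateServerDef(ctx, /*keep_alive_secs=*/0, bad_update.data(),
                             bad_update.size(), status);
  if (TF_GetCode(status) != TF_OK) {
    // The failed update (e.g. a cluster def naming an unreachable worker) does
    // not poison the context; retry with the known-good server def.
    TFE_ContextUpdateServerDef(ctx, 0, serialized.data(), serialized.size(),
                               status);
  }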
#include "tensorflow/c/eager/tfe_op_internal.h" #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/common_runtime/composite_device.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h" #include "tensorflow/core/lib/monitoring/counter.h" @@ -638,3 +639,35 @@ TFE_TensorHandle* TFE_NewTensorHandleFromTensor(TFE_Context* ctx, TF_Tensor* t, return tensorflow::wrap( tensorflow::unwrap(ctx)->CreateLocalHandle(t->tensor)); } + +TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx, + TFE_TensorHandle** handles, + int* num_handles, + TF_Status* status) { + std::vector tensor_handles; + tensor_handles.reserve(*num_handles); + for (int i = 0; i < *num_handles; ++i) { + tensor_handles.push_back( + tensorflow::TensorHandleFromInterface(tensorflow::unwrap(handles[i]))); + } + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); + tensorflow::TensorHandle* handle = nullptr; + status->status = tensorflow::TensorHandle::CreatePackedHandle( + std::move(tensor_handles), context, &handle); + return tensorflow::wrap(handle); +} + +void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable, + TF_Status* status) { + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); + context->SetAllowSoftPlacement(enable); +} + +void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable, + TF_Status* status) { + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); + context->SetLogDevicePlacement(enable); +} diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index d1e99d86180..1b8efe61ee0 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -431,6 +431,9 @@ TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx, // A reference to an op's name -> attribute mapping typedef struct TFE_OpAttrs TFE_OpAttrs; +// Fetch a reference to `op`'s attributes. The returned reference is only valid +// while `op` is alive. +const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op); // Add attributes in `attrs` to `op`. // // Does not overwrite or update existing attributes, but adds new ones. @@ -538,6 +541,26 @@ TF_CAPI_EXPORT extern TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_CAPI_EXPORT TFE_TensorHandle* TFE_NewTensorHandleFromTensor( TFE_Context* ctx, TF_Tensor* t, TF_Status* status); +// Create a packed TensorHandle with the given list of TensorHandles. +// If `handles` are on the same device, assign the same device to the packed +// handle; if `handles` are on different deivces, assign a CompositeDevice to +// it. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_CreatePackedTensorHandle( + TFE_Context* ctx, TFE_TensorHandle** handles, int* num_handles, + TF_Status* status); + +// Configure soft device placement policy for the eager executor. Note this +// policy is applied to any subsequent op executions. +TF_CAPI_EXPORT void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, + unsigned char enable, + TF_Status* status); + +// Configure device placement policy logging for the eager executor. Note this +// policy is applied to any subsequent op executions. 
+TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, + unsigned char enable, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api_remote_test.cc b/tensorflow/c/eager/c_api_remote_test.cc index 7c6836af69b..d04e4ef4212 100644 --- a/tensorflow/c/eager/c_api_remote_test.cc +++ b/tensorflow/c/eager/c_api_remote_test.cc @@ -168,7 +168,11 @@ string MatMulFunction() { return def.SerializeAsString(); } -void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) { +// If heavy_load_on_streaming_rpc is true, send some RPC requests before the one +// that creates a remote input, to simulate a scenario in which the remote +// input is not ready when we start running an op or a function. +void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func, + bool heavy_load_on_streaming_rpc) { tensorflow::ServerDef server_def = GetServerDef(3); // This server def has the task index set to 0. @@ -193,47 +197,62 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) { TFE_ContextOptionsSetAsync(opts, static_cast<bool>(async)); TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); TFE_Context* ctx = TFE_NewContext(opts, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); TFE_DeleteContextOptions(opts); TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx); TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx); + std::vector<TFE_TensorHandle*> handles_task0; + if (heavy_load_on_streaming_rpc) { + // Send 50 tensor copy requests to simulate that some RPC requests have + // already been enqueued. + for (int i = 0; i < 50; ++i) { + handles_task0.push_back(TestMatrixTensorHandle(ctx)); + } + } const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0"; const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0"; + std::vector<TFE_TensorHandle*> handles_task2; + for (auto* h_task0 : handles_task0) { + handles_task2.push_back( + TFE_TensorHandleCopyToDevice(h_task0, ctx, task2_name, status)); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + } + auto* h1_task2 = TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); TFE_Op* matmul = nullptr; if (func) { string function_def = MatMulFunction(); TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(), status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); matmul = TFE_NewOp(ctx, "MatMulFunction", status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); TFE_OpAddInput(matmul, h0_task0, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); TFE_OpAddInput(matmul, h1_task2, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); } else { // Handles are on task0 (local), and task2, but op is on task1.
matmul = MatMulOp(ctx, h0_task0, h1_task2); } if (remote) { TFE_OpSetDevice(matmul, task1_name, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); } else if (!async) { // Set the local device to CPU to easily validate mirroring string cpu_device_name; ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU")); TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); auto remote_arg = tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2)); // The input handles should never change since they have been mirrored. @@ -243,7 +262,7 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) { TFE_TensorHandle* retvals[1]; int num_retvals = 1; TFE_Execute(matmul, &retvals[0], &num_retvals, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); // TODO(gjn): Add support for waiting on async local mirrors if (!remote && !async) { @@ -255,10 +274,10 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) { auto* retval_task0 = TFE_TensorHandleCopyToDevice( retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); TFE_DeleteTensorHandle(retval_task0); float product[4] = {0}; EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); @@ -273,12 +292,18 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) { TFE_DeleteTensorHandle(h1_task0); TFE_DeleteTensorHandle(h1_task2); TFE_DeleteTensorHandle(retvals[0]); + for (auto* h : handles_task0) { + TFE_DeleteTensorHandle(h); + } + for (auto* h : handles_task2) { + TFE_DeleteTensorHandle(h); + } TFE_DeleteOp(matmul); TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); TFE_ExecutorWaitForAllPendingNodes(executor, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); TFE_DeleteExecutor(executor); if (func) { TFE_ContextRemoveFunction(ctx, "MatMulFunction", status); @@ -293,22 +318,260 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) { } TEST(CAPI, RemoteExecuteSilentCopies) { - TestRemoteExecuteSilentCopies(false, true, false); + TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/true, + /*func=*/false, + /*heavy_load_on_streaming_rpc=*/false); } TEST(CAPI, RemoteExecuteSilentCopiesAsync) { - TestRemoteExecuteSilentCopies(true, true, false); + TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/true, /*func=*/false, + /*heavy_load_on_streaming_rpc=*/false); } TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) { - TestRemoteExecuteSilentCopies(true, true, true); + TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/true, /*func=*/true, + /*heavy_load_on_streaming_rpc=*/false); } TEST(CAPI, RemoteExecuteSilentCopiesLocal) { - TestRemoteExecuteSilentCopies(false, false, false); + TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/false, + /*func=*/false, + /*heavy_load_on_streaming_rpc=*/false); } TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) { - 
TestRemoteExecuteSilentCopies(true, false, false); + TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, + /*func=*/false, + /*heavy_load_on_streaming_rpc=*/false); } TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) { - TestRemoteExecuteSilentCopies(true, false, true); + TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, /*func=*/true, + /*heavy_load_on_streaming_rpc=*/false); +} +TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) { + // A remote input may be not ready when we start running a function. Test that + // the function execution should wait until the remote input is ready. + TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, /*func=*/true, + /*heavy_load_on_streaming_rpc=*/true); +} + +// Add the values of three variables on three different tasks. +string AddVariablesFunction() { + tensorflow::FunctionDef def; + CHECK(tensorflow::protobuf::TextFormat::ParseFromString( + " signature {" + " name: 'AddVariablesFunction'" + " input_arg {" + " name: 'var'" + " type: DT_RESOURCE" + " }" + " output_arg {" + " name: 'sum'" + " type: DT_FLOAT" + " }" + " }" + " node_def {" + " name: 'read0'" + " op: 'ReadVariableOp'" + " input: 'var'" + " device: '/job:localhost/replica:0/task:0/device:CPU:0'" + " attr {" + " key: 'dtype'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " node_def {" + " name: 'read1'" + " op: 'ReadVariableOp'" + " input: 'var'" + " device: '/job:localhost/replica:0/task:1/device:CPU:0'" + " attr {" + " key: 'dtype'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " node_def {" + " name: 'read2'" + " op: 'ReadVariableOp'" + " input: 'var'" + " device: '/job:localhost/replica:0/task:2/device:CPU:0'" + " attr {" + " key: 'dtype'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " node_def {" + " name: 'add1'" + " op: 'Add'" + " input: 'read0:value:0'" + " input: 'read1:value:0'" + " attr {" + " key: 'T'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " node_def {" + " name: 'add2'" + " op: 'Add'" + " input: 'add1:z:0'" + " input: 'read2:value:0'" + " attr {" + " key: 'T'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " ret {" + " key: 'sum'" + " value: 'add2:z:0'" + " }", + &def)); + return def.SerializeAsString(); +} + +void VarIsInitialized(TFE_Context* ctx, TFE_TensorHandle* var_handle) { + TF_Status* status = TF_NewStatus(); + TFE_Op* op = TFE_NewOp(ctx, "VarIsInitializedOp", status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TFE_OpAddInput(op, var_handle, status); + TFE_TensorHandle* is_initialized[1] = {nullptr}; + int num_retvals = 1; + TFE_Execute(op, &is_initialized[0], &num_retvals, status); + CHECK_EQ(1, num_retvals); + TF_Tensor* t = TFE_TensorHandleResolve(is_initialized[0], status); + bool initialized = false; + memcpy(&initialized, TF_TensorData(t), TF_TensorByteSize(t)); + EXPECT_EQ(initialized, true); + TF_DeleteTensor(t); + TFE_DeleteTensorHandle(is_initialized[0]); + TFE_DeleteOp(op); + delete status; +} + +void TestFunctionWithPackedInput(const bool remote) { + tensorflow::ServerDef server_def = GetServerDef(3); + + // This server def has the task index set to 0. 
+ string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + std::unique_ptr worker_server1; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server1) + .ok()); + ASSERT_TRUE(worker_server1->Start().ok()); + + server_def.set_task_index(2); + std::unique_ptr worker_server2; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server2) + .ok()); + ASSERT_TRUE(worker_server2->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(/*enable=*/true)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + const char task0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0"; + const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0"; + const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0"; + + // Create one variable per task. + TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task0_name); + TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task1_name); + TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task2_name); + + // Add a sync point in order to make sure that variables have been initialized + // before the function execution starts. + // TODO(b/155789951): Remove once b/155789951 is fixed. + VarIsInitialized(ctx, h1); + VarIsInitialized(ctx, h2); + + // Pack 3 variable handles into one TFE_TensorHandle. + int num_replicas = 3; + std::vector handles = {h0, h1, h2}; + TFE_TensorHandle* packed_handle = + TFE_CreatePackedTensorHandle(ctx, handles.data(), &num_replicas, status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + EXPECT_EQ(TFE_TensorHandleDataType(packed_handle), TF_RESOURCE); + EXPECT_EQ(TFE_TensorHandleNumDims(packed_handle, status), 0); + EXPECT_EQ(TFE_TensorHandleNumElements(packed_handle, status), 1); + + const string composite_device_name = + "/job:localhost/replica:0/task:0/device:COMPOSITE:0"; + EXPECT_EQ(TFE_TensorHandleDeviceName(packed_handle, status), + composite_device_name); + EXPECT_EQ(TFE_TensorHandleBackingDeviceName(packed_handle, status), + composite_device_name); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + // Register and run a function which returns the sum of 3 variables. 
+ const string function_def = AddVariablesFunction(); + TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(), + status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TFE_Op* func = TFE_NewOp(ctx, "AddVariablesFunction", status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TFE_OpAddInput(func, packed_handle, status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + if (remote) { + TFE_OpSetDevice(func, task1_name, status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + } + + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; + TFE_Execute(func, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + ASSERT_EQ(1, num_retvals); + TFE_DeleteOp(func); + TFE_DeleteTensorHandle(packed_handle); + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TFE_DeleteTensorHandle(retvals[0]); + float sum = 0; + EXPECT_EQ(sizeof(sum), TF_TensorByteSize(t)); + memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(sum, 6.0); + + TFE_DeleteTensorHandle(h0); + TFE_DeleteTensorHandle(h1); + TFE_DeleteTensorHandle(h2); + + TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorWaitForAllPendingNodes(executor, status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TFE_DeleteExecutor(executor); + TFE_ContextRemoveFunction(ctx, "AddVariablesFunction", status); + TFE_DeleteContext(ctx); + + TF_DeleteStatus(status); + + // TODO(b/136478427): Figure out how to correctly shut the server down. + worker_server1.release(); + worker_server2.release(); +} + +TEST(CAPI, TestLocalFunctionWithPackedInput) { + TestFunctionWithPackedInput(/*remote=*/false); +} + +TEST(CAPI, TestRemoteFunctionWithPackedInput) { + TestFunctionWithPackedInput(/*remote=*/true); } void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) { @@ -381,150 +644,4 @@ TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPC) { TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPCAsync) { TestRemoteExecuteDeleteContextWithOutstandingRPC(true); } - -void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle, - const std::vector& expected_values) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TF_Tensor* t = TFE_TensorHandleResolve(handle, status.get()); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - std::unique_ptr actual_values(new float[expected_values.size()]); - EXPECT_EQ(sizeof(float) * expected_values.size(), TF_TensorByteSize(t)); - memcpy(actual_values.get(), TF_TensorData(t), TF_TensorByteSize(t)); - TF_DeleteTensor(t); - - for (int i = 0; i < expected_values.size(); i++) { - EXPECT_EQ(expected_values[i], actual_values[i]) - << "Mismatch in expected values at (zero-based) index " << i; - } -} - -void CheckRemoteMatMulExecutesOK(TFE_Context* ctx, - const char* remote_device_name, - const char* local_device_name) { - TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx); - - TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0); - TFE_OpSetDevice(matmul, remote_device_name, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - - TFE_TensorHandle* retvals[1]; - int num_retvals = 1; - TFE_Execute(matmul, &retvals[0], &num_retvals, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - - auto* retval_task0 = - 
TFE_TensorHandleCopyToDevice(retvals[0], ctx, local_device_name, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - - CheckTFE_TensorHandleHasFloats(retval_task0, {7, 10, 15, 22}); - - TFE_DeleteTensorHandle(retval_task0); - TFE_DeleteTensorHandle(h0_task0); - TFE_DeleteTensorHandle(retvals[0]); - - TFE_DeleteOp(matmul); - - TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); - TFE_ExecutorWaitForAllPendingNodes(executor, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteExecutor(executor); - TF_DeleteStatus(status); -} - -void TestRemoteExecuteChangeServerDef(bool async) { - tensorflow::ServerDef server_def = GetServerDef(2); - - // This server def has the task index set to 0. - string serialized = server_def.SerializeAsString(); - - server_def.set_task_index(1); - - std::unique_ptr worker_server; - ASSERT_TRUE(tensorflow::GrpcServer::Create( - server_def, tensorflow::Env::Default(), &worker_server) - .ok()); - ASSERT_TRUE(worker_server->Start().ok()); - - TF_Status* status = TF_NewStatus(); - TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_ContextOptionsSetAsync(opts, static_cast(async)); - TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); - TFE_Context* ctx = TFE_NewContext(opts, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteContextOptions(opts); - - TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - - const char remote_device_name[] = - "/job:localhost/replica:0/task:1/device:CPU:0"; - const char local_device_name[] = - "/job:localhost/replica:0/task:0/device:CPU:0"; - CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name); - - TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); - TFE_ExecutorWaitForAllPendingNodes(executor, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - - // TODO(b/136478427): Figure out how to correctly shut the server down. - worker_server.release(); - - // Update the server def with a new set of names (worker instead of - // localhost). - tensorflow::ServerDef updated_server_def = GetServerDef("worker", 2); - serialized = updated_server_def.SerializeAsString(); - - updated_server_def.set_task_index(1); - tensorflow::Status s = tensorflow::GrpcServer::Create( - updated_server_def, tensorflow::Env::Default(), &worker_server); - ASSERT_TRUE(s.ok()) << s.error_message(); - ASSERT_TRUE(worker_server->Start().ok()); - - TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - - // Create a new tensor_handle. - TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle(ctx); - - // Check that copying it to the old remote device (named localhost) fails. - TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status); - EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status); - - // Copying and executing on the new remote device works. 
- const char new_remote_device_name[] = - "/job:worker/replica:0/task:1/device:CPU:0"; - const char new_local_device_name[] = - "/job:worker/replica:0/task:0/device:CPU:0"; - - auto* h0_task1_new = TFE_TensorHandleCopyToDevice( - h0_task0_new, ctx, new_remote_device_name, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - - TFE_DeleteTensorHandle(h0_task0_new); - TFE_DeleteTensorHandle(h0_task1_new); - - CheckRemoteMatMulExecutesOK(ctx, new_remote_device_name, - new_local_device_name); - - TFE_ExecutorWaitForAllPendingNodes(executor, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteExecutor(executor); - - TF_DeleteStatus(status); - - TFE_DeleteContext(ctx); - - // TODO(b/136478427): Figure out how to correctly shut the server down. - worker_server.release(); -} - -TEST(CAPI, RemoteExecuteChangeServerDef) { - TestRemoteExecuteChangeServerDef(false); -} -TEST(CAPI, RemoteExecuteChangeServerDefAsync) { - TestRemoteExecuteChangeServerDef(true); -} - } // namespace diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 0e4183dad16..724176505ba 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -1132,51 +1132,6 @@ void BM_ExecuteFunction(int iters, int async) { } BENCHMARK(BM_ExecuteFunction)->Arg(0)->Arg(1); -TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value, - TF_Status* status) { - // Create the variable handle. - TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpSetAttrType(op, "dtype", TF_FLOAT); - TFE_OpSetAttrShape(op, "shape", {}, 0, status); - TFE_OpSetAttrString(op, "container", "", 0); - TFE_OpSetAttrString(op, "shared_name", "", 0); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_TensorHandle* var_handle = nullptr; - int num_retvals = 1; - TFE_Execute(op, &var_handle, &num_retvals, status); - TFE_DeleteOp(op); - if (TF_GetCode(status) != TF_OK) return nullptr; - CHECK_EQ(1, num_retvals); - - // Assign 'value' to it. - op = TFE_NewOp(ctx, "AssignVariableOp", status); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpSetAttrType(op, "dtype", TF_FLOAT); - TFE_OpAddInput(op, var_handle, status); - - // Convert 'value' to a TF_Tensor then a TFE_TensorHandle. - std::unique_ptr t( - TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor); - memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get())); - - std::unique_ptr - value_handle(TFE_NewTensorHandle(t.get(), status), - TFE_DeleteTensorHandle); - if (TF_GetCode(status) != TF_OK) return nullptr; - - TFE_OpAddInput(op, value_handle.get(), status); - if (TF_GetCode(status) != TF_OK) return nullptr; - - num_retvals = 0; - TFE_Execute(op, nullptr, &num_retvals, status); - TFE_DeleteOp(op); - if (TF_GetCode(status) != TF_OK) return nullptr; - CHECK_EQ(0, num_retvals); - - return var_handle; -} - TEST(CAPI, Variables) { // Variables use resource handles, so this is really a test for resource // tensor handling. 
@@ -1186,7 +1141,7 @@ TEST(CAPI, Variables) { ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_TensorHandle* var_handle = CreateVariable(ctx, 12.0, status); + TFE_TensorHandle* var_handle = TestVariable(ctx, 12.0); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status); @@ -1227,7 +1182,7 @@ void BM_ReadVariable(int iters) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_TensorHandle* var_handle = CreateVariable(ctx, 5.0, status); + TFE_TensorHandle* var_handle = TestVariable(ctx, 5.0); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status); @@ -1248,6 +1203,8 @@ void BM_ReadVariable(int iters) { CHECK_EQ(0, TFE_TensorHandleNumDims(h, status)); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); h = nullptr; + TFE_OpAddInput(op, var_handle, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); } tensorflow::testing::StopTiming(); TFE_DeleteOp(op); @@ -1591,15 +1548,11 @@ TEST(CAPI, TestTFE_OpAddAttrs) { TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status); TFE_OpSetAttrType(var_op, "dtype", TF_INT64); TFE_OpSetAttrShape(var_op, "shape", {}, 0, status); - // There is currently no API to fetch attributes from an operation, fetching - // happens only as an implementation detail of custom devices. - tensorflow::EagerOperation* operation = - OperationFromInterface(tensorflow::unwrap(var_op)); - TFE_OpAttrs attributes{&operation->Attrs()}; + const TFE_OpAttrs* attributes = TFE_OpGetAttrs(var_op); TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status); TFE_OpSetAttrType(copy_op, "dtype", TF_FLOAT); - TFE_OpAddAttrs(copy_op, &attributes); + TFE_OpAddAttrs(copy_op, attributes); unsigned char is_list = 0; ASSERT_EQ(TF_ATTR_TYPE, TFE_OpGetAttrType(copy_op, "dtype", &is_list, status)); @@ -1631,14 +1584,10 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_OpSetAttrType(var_op, "dtype", TF_INT64); TFE_OpSetAttrShape(var_op, "shape", {}, 0, status); - // There is currently no API to fetch attributes from an operation, fetching - // happens only as an implementation detail of custom devices. - tensorflow::EagerOperation* operation = - OperationFromInterface(tensorflow::unwrap(var_op)); - TFE_OpAttrs attributes{&operation->Attrs()}; + const TFE_OpAttrs* attributes = TFE_OpGetAttrs(var_op); TF_Buffer* serialized_attr_values = TF_NewBuffer(); - TFE_OpAttrsSerialize(&attributes, serialized_attr_values, status); + TFE_OpAttrsSerialize(attributes, serialized_attr_values, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); tensorflow::NameAttrList name_and_attrs; ASSERT_TRUE(name_and_attrs.ParseFromArray(serialized_attr_values->data, diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc index e67e17963b3..29b624b8537 100644 --- a/tensorflow/c/eager/c_api_test_util.cc +++ b/tensorflow/c/eager/c_api_test_util.cc @@ -133,6 +133,58 @@ TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx) { return th; } +TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value, + const tensorflow::string& device_name) { + TF_Status* status = TF_NewStatus(); + // Create the variable handle. 
+ TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetAttrType(op, "dtype", TF_FLOAT); + TFE_OpSetAttrShape(op, "shape", {}, 0, status); + TFE_OpSetAttrString(op, "container", "", 0); + TFE_OpSetAttrString(op, "shared_name", "", 0); + if (!device_name.empty()) { + TFE_OpSetDevice(op, device_name.c_str(), status); + } + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_TensorHandle* var_handle = nullptr; + int num_retvals = 1; + TFE_Execute(op, &var_handle, &num_retvals, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_DeleteOp(op); + if (TF_GetCode(status) != TF_OK) return nullptr; + CHECK_EQ(1, num_retvals); + + // Assign 'value' to it. + op = TFE_NewOp(ctx, "AssignVariableOp", status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetAttrType(op, "dtype", TF_FLOAT); + TFE_OpAddInput(op, var_handle, status); + + // Convert 'value' to a TF_Tensor then a TFE_TensorHandle. + std::unique_ptr t( + TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor); + memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get())); + + std::unique_ptr + value_handle(TFE_NewTensorHandle(t.get(), status), + TFE_DeleteTensorHandle); + if (TF_GetCode(status) != TF_OK) return nullptr; + + TFE_OpAddInput(op, value_handle.get(), status); + if (TF_GetCode(status) != TF_OK) return nullptr; + + num_retvals = 0; + TFE_Execute(op, nullptr, &num_retvals, status); + TFE_DeleteOp(op); + if (TF_GetCode(status) != TF_OK) return nullptr; + CHECK_EQ(0, num_retvals); + + TF_DeleteStatus(status); + + return var_handle; +} + TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) { TF_Status* status = TF_NewStatus(); diff --git a/tensorflow/c/eager/c_api_test_util.h b/tensorflow/c/eager/c_api_test_util.h index 11ae6d1181b..4c43f8d5833 100644 --- a/tensorflow/c/eager/c_api_test_util.h +++ b/tensorflow/c/eager/c_api_test_util.h @@ -42,6 +42,11 @@ TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(TFE_Context* ctx); // Return a tensor handle containing a 3x2 matrix of floats TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx); +// Return a variable handle referring to a variable with the given initial value +// on the given device. +TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value, + const tensorflow::string& device_name = ""); + // Return an add op multiplying `a` by `b`. TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b); diff --git a/tensorflow/c/eager/c_api_unified_experimental.cc b/tensorflow/c/eager/c_api_unified_experimental.cc index 68afffb28b4..e5030a602b3 100644 --- a/tensorflow/c/eager/c_api_unified_experimental.cc +++ b/tensorflow/c/eager/c_api_unified_experimental.cc @@ -17,6 +17,8 @@ limitations under the License. 
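The relocated TestVariable helper declared above now takes an optional device name, which the remote tests in this change rely on; a small sketch of both call forms, assuming a live TFE_Context* ctx (the device string mirrors the ones used in the tests):

  // Default placement, as in the updated CAPI.Variables test.
  TFE_TensorHandle* local_var = TestVariable(ctx, 12.0);
  // Explicit placement on a remote task, as in the resource-access test.
  TFE_TensorHandle* remote_var =
      TestVariable(ctx, 2.0, "/job:localhost/replica:0/task:1/device:CPU:0");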
#include +#include "absl/container/flat_hash_map.h" +#include "absl/strings/str_cat.h" #include "tensorflow/c/eager/c_api_unified_experimental_internal.h" #include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_status.h" @@ -26,6 +28,51 @@ using tensorflow::string; using tensorflow::internal::OutputList; using tensorflow::internal::unwrap; +namespace tensorflow { +namespace internal { +typedef absl::flat_hash_map FactoriesMap; + +static FactoriesMap& GetFactories() { + static FactoriesMap* factories = new FactoriesMap; + return *factories; +} + +static const char* default_factory = ""; + +void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) { + assert((!GetFactories().count(name)) || + (GetFactories()[name] == factory) && + "Duplicate tracing factory registration"); + GetFactories()[name] = factory; +} + +void SetDefaultTracingEngine(const char* name) { default_factory = name; } + +static ExecutionContext* CreateTracingExecutionContext(const char* fn_name, + TF_Status* s) { + auto entry = GetFactories().find(default_factory); + if (entry != GetFactories().end()) return entry->second(fn_name, s); + string msg = absl::StrCat( + "No tracing engine factory has been registered with the key '", + default_factory, "' (available: "); + // Ensure deterministic (sorted) order in the error message + std::set factories_sorted; + for (const auto& factory : GetFactories()) + factories_sorted.insert(factory.first); + const char* comma = ""; + for (const string& factory : factories_sorted) { + msg += comma + factory; + comma = ", "; + } + msg += ")"; + + TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str()); + return nullptr; +} + +} // end namespace internal +} // end namespace tensorflow + // ============================================================================= // Public C API entry points // @@ -36,6 +83,28 @@ using tensorflow::internal::unwrap; // // ============================================================================= +void TF_SetTracingImplementation(const char* name) { + tensorflow::internal::SetDefaultTracingEngine(name); +} + +// Creates a new TensorFlow function, it is an execution context attached to a +// given tracing context. 
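For orientation, a condensed sketch of the tracing flow these new entry points enable, mirroring the updated TestBasicGraph test further down; error checking and cleanup are omitted, and s is assumed to be a valid TF_Status*:

  TF_SetTracingImplementation("graphdef");
  TF_ExecutionContext* fn_ctx = TF_CreateFunction("my_fn", s);
  TF_AbstractTensor* arg = TF_AddFunctionParameter(fn_ctx, TF_FLOAT, s);
  TF_AbstractOp* add = TF_NewAbstractOp(fn_ctx);
  TF_AbstractOpSetOpType(add, "Add", s);
  TF_AbstractOpSetOpName(add, "my_add", s);
  TF_AbstractTensor* inputs[2] = {arg, arg};
  TF_OutputList* outs = TF_NewOutputList();
  TF_ExecuteOperation(add, 2, inputs, outs, fn_ctx, s);
  TF_DeleteAbstractOp(add);
  // Finalize consumes fn_ctx and yields a function that an eager execution
  // context can register and then call by name.
  TF_AbstractFunction* fn = TF_FinalizeFunction(fn_ctx, outs, s);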
+TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* s) { + return wrap(tensorflow::internal::CreateTracingExecutionContext(fn_name, s)); +} + +TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx, + TF_OutputList* outputs, TF_Status* s) { + auto* func = wrap(unwrap(ctx)->Finalize(unwrap(outputs), s)); + TF_DeleteExecutionContext(ctx); + return func; +} + +TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func, + TF_DataType dtype, TF_Status* s) { + return wrap(unwrap(func)->AddParameter(dtype, s)); +} + void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete unwrap(c); } TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) { @@ -58,6 +127,10 @@ int TF_OutputListNumOutputs(TF_OutputList* o) { TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) { return wrap(unwrap(o)->outputs[i]); } +void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor, + TF_Status* s) { + unwrap(o)->outputs.push_back(unwrap(tensor)); +} void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type, TF_Status* s) { diff --git a/tensorflow/c/eager/c_api_unified_experimental.h b/tensorflow/c/eager/c_api_unified_experimental.h index be8fc64c2e1..86c59a7f625 100644 --- a/tensorflow/c/eager/c_api_unified_experimental.h +++ b/tensorflow/c/eager/c_api_unified_experimental.h @@ -49,15 +49,26 @@ typedef struct TF_AbstractOp TF_AbstractOp; // setting functional attributes of other composite ops e.g. control flow. typedef struct TF_AbstractFunction TF_AbstractFunction; -// Creates a context for tracing the execution of operations into a function. -TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s); +// This allows the client to swap the implementation of the tracing engine. +// Any future call to TF_CreateFunction will use the implementation defined +// here. +void TF_SetTracingImplementation(const char* name); + +// Creates a new TensorFlow function. A Function is an execution context, and as +// such it can trace operations through TF_ExecuteOperation. After completing +// tracing, a function can be obtained by TF_FinalizeFunction. +TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* status); // Creates a context for eager execution of operations. TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions*, TF_Status* s); - void TF_DeleteExecutionContext(TF_ExecutionContext*); +// Add a new parameter to a TensorFlow Function. +// TODO(aminim): what about shape? +TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func, + TF_DataType dtype, TF_Status* s); + // Create an operation suitable to use with the provided context. The operation // requires its type (e.g. "AddV2") to be set independently. TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* ctx); @@ -77,19 +88,21 @@ void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name, void TF_DeleteAbstractTensor(TF_AbstractTensor*); // TF_OutputList holds the list of TF_AbstractTensor that results from executing -// an operation. -// It just lets us not specify the number of outputs of an operation -// beforehand. This forces a memory allocation in the runtime, which is bad, but -// it allows for generic code. -// TODO(aminim): the description above isn't clear with respect to -// TF_OutputListNumOutputs and the current eager implementation which requires -// the number of outputs to be set by the client. +// an operation, or provided to create a function. 
+// When executing an operation in an eager context, the expected number of +// outputs must be set beforehand with `TF_OutputListSetNumOutputs`. typedef struct TF_OutputList TF_OutputList; TF_OutputList* TF_NewOutputList(); void TF_DeleteOutputList(TF_OutputList* o); -void TF_OutputListSetNumOutputs(TF_OutputList* o, int, TF_Status*); +// Prepare tracing to the expected number of output for an operation. +void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs, TF_Status*); +// Return the number of outputs in the list. int TF_OutputListNumOutputs(TF_OutputList* o); +// Return the `i`th output in the list. TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i); +// Append a tensor at the end of the output list, growing its size by one. +void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor, + TF_Status*); // TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe // capture some inputs and then add a node in the graph. The output tensors are @@ -100,13 +113,12 @@ void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs, TF_ExecutionContext* ctx, TF_Status* s); // Creates a new TF_AbstractFunction from the current tracing states in the -// context. The returned TF_GraphToFunction must be deleted by the client. +// context. The provided `ctx` is consumed by this API call and deleted. +// The returned TF_AbstractFunction must be deleted by the client, // TODO(aminim): clarify the contract on the state of the context after this // call. -TF_AbstractFunction* TF_ExecutionContextToFunction( - const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs, - const TF_AbstractTensor* inputs, int num_outputs, - const TF_AbstractTensor* outputs, TF_Status* status); +TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx, + TF_OutputList*, TF_Status*); void TF_DeleteAbstractFunction(TF_AbstractFunction*); diff --git a/tensorflow/c/eager/c_api_unified_experimental_eager.cc b/tensorflow/c/eager/c_api_unified_experimental_eager.cc index 820c61445fb..cf8cf845834 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_eager.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_eager.cc @@ -123,6 +123,17 @@ class EagerContext : public ExecutionContext { } } + AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override { + TF_SetStatus(s, TF_INVALID_ARGUMENT, + "Can't add function parameter on an eager context."); + return nullptr; + } + AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override { + TF_SetStatus(s, TF_INVALID_ARGUMENT, + "Can't use finalize function on an eager context."); + return nullptr; + } + void RegisterFunction(AbstractFunction* afunc, TF_Status* s) override { auto* func = afunc->GetTfFunction(s); if (!func) { diff --git a/tensorflow/c/eager/c_api_unified_experimental_graph.cc b/tensorflow/c/eager/c_api_unified_experimental_graph.cc index 36f8353894b..dd5a95b3526 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_graph.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_graph.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/c/c_api.h" #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_unified_experimental.h" @@ -114,12 +115,14 @@ struct GraphFunction : public AbstractFunction { static constexpr AbstractFunctionKind kKind = kGraphFunc; }; -// GraphContext wraps a TF_Graph and manages the "execution" of operation, i.e. -// adding them to the graph. 
+// GraphContext wraps a TF_Graph modeling a single function and manages the +// "execution" of operation, i.e. adding them to the function. class GraphContext : public ExecutionContext { public: - GraphContext() - : ExecutionContext(kKind), graph_(new TF_Graph(), TF_DeleteGraph) {} + explicit GraphContext(const char* name) + : ExecutionContext(kKind), + graph_(new TF_Graph(), TF_DeleteGraph), + name_(name) {} AbstractOp* CreateOperation() override { // TODO(srbs): Should the lifetime of this op be tied to the context. @@ -136,6 +139,10 @@ class GraphContext : public ExecutionContext { return; } auto* tf_opdesc = graph_op->op_.release(); + if (tf_opdesc == nullptr) { + TF_SetStatus(s, TF_INVALID_ARGUMENT, "AbstractOp is incomplete."); + return; + } for (int i = 0; i < num_inputs; ++i) { auto* graph_tensor = dyncast(inputs[i]); if (!graph_tensor) { @@ -164,24 +171,38 @@ class GraphContext : public ExecutionContext { } } - TF_Function* ToFunction(const char* fn_name, int num_inputs, - const GraphTensor* inputs, int num_outputs, - const GraphTensor* outputs, TF_Status* status) const { - std::vector graph_inputs; - graph_inputs.resize(num_inputs); + AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override { + TF_OperationDescription* opdesc = + TF_NewOperation(graph_.get(), "Placeholder", + absl::StrCat("_input_", inputs_.size()).c_str()); + TF_SetAttrType(opdesc, "dtype", dtype); + auto* operation = TF_FinishOperation(opdesc, s); + if (!s->status.ok()) return nullptr; + + inputs_.push_back(TF_Output{operation, 0}); + return new GraphTensor(inputs_.back(), this); + } + + AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override { + std::unique_ptr func(new GraphFunction); std::vector graph_outputs; - graph_outputs.resize(num_outputs); - for (int i = 0; i < num_inputs; i++) { - graph_inputs[i] = inputs[i].output; - } - for (int i = 0; i < num_outputs; i++) { - graph_outputs[i] = outputs[i].output; + graph_outputs.reserve(outputs->outputs.size()); + for (AbstractTensor* abstract_output : outputs->outputs) { + GraphTensor* output = dyncast(abstract_output); + if (!output) { + TF_SetStatus(s, TF_UNIMPLEMENTED, + "Returning a non-graph tensor from a function has not " + "been implemented yet."); + return nullptr; + } + graph_outputs.push_back(output->output); } - return TF_GraphToFunction(graph_.get(), fn_name, 0, -1, nullptr, - graph_inputs.size(), graph_inputs.data(), - graph_outputs.size(), graph_outputs.data(), - nullptr, nullptr, fn_name, status); + func->func = TF_GraphToFunction( + graph_.get(), name_, 0, -1, nullptr, inputs_.size(), inputs_.data(), + graph_outputs.size(), graph_outputs.data(), nullptr, nullptr, name_, s); + if (TF_GetCode(s) != TF_OK) return nullptr; + return func.release(); } void RegisterFunction(AbstractFunction* func, TF_Status* s) override { @@ -195,54 +216,20 @@ class GraphContext : public ExecutionContext { private: std::unique_ptr graph_; + std::vector inputs_; + const char* name_; }; -// Helper that converts the graph currently held in the context into a function. 
-static AbstractFunction* ExecutionContextToFunction( - const ExecutionContext* fn_body, const char* fn_name, int num_inputs, - const AbstractTensor* inputs, int num_outputs, - const AbstractTensor* outputs, TF_Status* status) { - auto* graph_ctx = dyncast(fn_body); - if (graph_ctx == nullptr) { - TF_SetStatus(status, TF_INVALID_ARGUMENT, - "fn_body is not a TF_GraphContext."); - return nullptr; - } - auto* graph_inputs = dyncast(inputs); - if (!graph_inputs) { - TF_SetStatus(status, TF_INVALID_ARGUMENT, "inputs aren't GraphTensors."); - return nullptr; - } - auto* graph_outputs = dyncast(outputs); - if (!graph_outputs) { - TF_SetStatus(status, TF_INVALID_ARGUMENT, "outputs aren't GraphTensors."); - return nullptr; - } - GraphFunction* func = new GraphFunction; - func->func = graph_ctx->ToFunction(fn_name, num_inputs, graph_inputs, - num_outputs, graph_outputs, status); - return func; +static ExecutionContext* GraphTracingFactory(const char* name, TF_Status* s) { + return new GraphContext(name); } +// Register the tracing implemented in this file as the default tracing engine. +static bool register_tracing = [] { + RegisterTracingEngineFactory("graphdef", GraphTracingFactory); + SetDefaultTracingEngine("graphdef"); + return true; +}(); + } // namespace internal } // namespace tensorflow - -// ============================================================================= -// Public C API entry points -// These are only the entry points specific to the Graph API. -// ============================================================================= - -using tensorflow::internal::unwrap; - -TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s) { - return wrap(new tensorflow::internal::GraphContext()); -} - -TF_AbstractFunction* TF_ExecutionContextToFunction( - const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs, - const TF_AbstractTensor* inputs, int num_outputs, - const TF_AbstractTensor* outputs, TF_Status* status) { - return wrap(ExecutionContextToFunction(unwrap(fn_body), fn_name, num_inputs, - unwrap(inputs), num_outputs, - unwrap(outputs), status)); -} diff --git a/tensorflow/c/eager/c_api_unified_experimental_internal.h b/tensorflow/c/eager/c_api_unified_experimental_internal.h index ab085a20ff0..49212a230ee 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_internal.h +++ b/tensorflow/c/eager/c_api_unified_experimental_internal.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/core/platform/casts.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace internal { @@ -148,6 +149,17 @@ struct ExecutionContext { // Creates an empty AbstractOperation suitable to use with this context. virtual AbstractOp* CreateOperation() = 0; + // Add a function parameter and return the corresponding tensor. + // This is only valid with an ExecutionContext obtained from a TracingContext, + // it'll always error out with an eager context. + virtual AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) = 0; + + // Finalize this context and make a function out of it. The context is in a + // invalid state after this call and must be destroyed. + // This is only valid with an ExecutionContext obtained from a TracingContext, + // it'll always error out with an eager context. 
+ virtual AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) = 0; + // Registers a functions with this context, after this the function is // available to be called/referenced by its name in this context. virtual void RegisterFunction(AbstractFunction* func, TF_Status* s) = 0; @@ -156,6 +168,11 @@ struct ExecutionContext { const ExecutionContextKind k; }; +typedef ExecutionContext* (*FactoryFunction)(const char* fn_name, TF_Status*); +void SetDefaultTracingEngine(const char* name); +void RegisterTracingEngineFactory(const ::tensorflow::string& name, + FactoryFunction factory); + // Create utilities to wrap/unwrap: this convert from the C opaque types to the // C++ implementation, and back. #define MAKE_WRAP_UNWRAP(C_TYPEDEF, CPP_CLASS) \ diff --git a/tensorflow/c/eager/c_api_unified_experimental_test.cc b/tensorflow/c/eager/c_api_unified_experimental_test.cc index 170b82333d8..9776b4d13ed 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_test.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_test.cc @@ -29,7 +29,12 @@ using tensorflow::string; namespace tensorflow { namespace { -TEST(UnifedCAPI, TestBasicEager) { +class UnifiedCAPI : public ::testing::TestWithParam { + protected: + void SetUp() override { TF_SetTracingImplementation(GetParam()); } +}; + +TEST_P(UnifiedCAPI, TestBasicEager) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); TFE_ContextOptions* opts = TFE_NewContextOptions(); @@ -81,33 +86,18 @@ TEST(UnifedCAPI, TestBasicEager) { TF_DeleteExecutionContext(ctx); } -TEST(UnifedCAPI, TestBasicGraph) { +TEST_P(UnifiedCAPI, TestBasicGraph) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); - TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get()); + // Start a new function / execution context. + string fn_name = "double"; + TF_ExecutionContext* graph_ctx = + TF_CreateFunction(fn_name.c_str(), status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - // Add a placeholder to the graph. - auto* placeholder_op = TF_NewAbstractOp(graph_ctx); - TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get()); + auto* placeholder_t = + TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get()); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TF_AbstractOpSetAttrType(placeholder_op, "dtype", TF_FLOAT, status.get()); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - - // Build inputs and outputs. - TF_OutputList* placeholder_outputs = TF_NewOutputList(); - - // Execute. - TF_ExecuteOperation(placeholder_op, 0, nullptr, placeholder_outputs, - graph_ctx, status.get()); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - ASSERT_EQ(1, TF_OutputListNumOutputs(placeholder_outputs)); - TF_AbstractTensor* placeholder_t = TF_OutputListGet(placeholder_outputs, 0); - - // Delete placeholder op. - TF_DeleteAbstractOp(placeholder_op); // Build an abstract operation. auto* add_op = TF_NewAbstractOp(graph_ctx); @@ -123,16 +113,13 @@ TEST(UnifedCAPI, TestBasicGraph) { // Execute. TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TF_AbstractTensor* output_t = TF_OutputListGet(add_outputs, 0); // Clean up operation and inputs. 
TF_DeleteAbstractOp(add_op); - string fn_name = "double"; - TF_AbstractFunction* func = TF_ExecutionContextToFunction( - graph_ctx, fn_name.c_str(), 1, placeholder_t, 1, output_t, status.get()); - TF_DeleteAbstractTensor(placeholder_t); - TF_DeleteAbstractTensor(output_t); + TF_AbstractFunction* func = + TF_FinalizeFunction(graph_ctx, add_outputs, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Build eager context. TFE_ContextOptions* opts = TFE_NewContextOptions(); @@ -173,18 +160,161 @@ TEST(UnifedCAPI, TestBasicGraph) { ASSERT_EQ(*f_value, 4.0); TF_DeleteOutputList(add_outputs); - TF_DeleteOutputList(placeholder_outputs); TF_DeleteAbstractOp(fn_op); TF_DeleteAbstractTensor(input_t); TF_DeleteAbstractTensor(final_result); TF_DeleteTensor(f_t); TF_DeleteAbstractFunction(func); - TF_DeleteExecutionContext(graph_ctx); TF_DeleteExecutionContext(eager_execution_ctx); } -TEST(UnifedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) { +TEST_P(UnifiedCAPI, TestMultiOutputGraph) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_Status* s = status.get(); + + // Start a new function / execution context. + string fn_name = "two_adds"; + TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Create a first "Add" computing `arg0 + arg1`. + TF_AbstractTensor* add_output1; + { + // Build an abstract operation, inputs and output. + auto* add_op = TF_NewAbstractOp(graph_ctx); + TF_AbstractOpSetOpType(add_op, "Add", s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_AbstractOpSetOpName(add_op, "my_add1", s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_AbstractTensor* inputs[2] = {arg0, arg1}; + TF_OutputList* add_outputs = TF_NewOutputList(); + // Trace the operation now (create a node in the graph). + TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_DeleteAbstractOp(add_op); + // Extract the resulting tensor. + add_output1 = TF_OutputListGet(add_outputs, 0); + TF_DeleteOutputList(add_outputs); + } + + // Same with a second "Add" computing `arg1 + arg1`. + TF_AbstractTensor* add_output2; + { + // Build an abstract operation, inputs and output. + auto* add_op = TF_NewAbstractOp(graph_ctx); + TF_AbstractOpSetOpType(add_op, "Add", s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_AbstractOpSetOpName(add_op, "my_add2", s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_AbstractTensor* inputs[2] = {arg1, arg1}; + TF_OutputList* add_outputs = TF_NewOutputList(); + // Trace the operation now (create a node in the graph). + TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_DeleteAbstractOp(add_op); + // Extract the resulting tensor. + add_output2 = TF_OutputListGet(add_outputs, 0); + TF_DeleteOutputList(add_outputs); + } + + // Finalize the function by providing the returned values. + TF_AbstractFunction* func; + { + // We want to return the output of both add operations, create a new list + // and populate it. 
+ TF_OutputList* func_outputs = TF_NewOutputList(); + TF_OutputListPushBack(func_outputs, add_output1, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_OutputListPushBack(func_outputs, add_output2, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + func = TF_FinalizeFunction(graph_ctx, func_outputs, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_DeleteOutputList(func_outputs); + } + + /** + * We traced so far this function: + * + * def two_adds(a, b): + * my_add1 = a + b + * my_add2 = b + b + * return my_add1, my_add2 + * + * Now we will execute this function with an eager context: + * + * output1, output2 = two_adds(2.0, 3.0) + * + * and check that we got 5.0 and 6.0 as results. + */ + + // Build eager context. + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TF_ExecutionContext* eager_execution_ctx = + TF_NewEagerExecutionContext(opts, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TFE_DeleteContextOptions(opts); + + TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Build the abstract op to run the function. + TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx); + TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Build two abstract input tensors as function arguments. + std::vector func_args; + { + TFE_Context* eager_ctx = + TF_ExecutionContextGetTFEContext(eager_execution_ctx); + TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f); + func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s)); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + input_eager = TestScalarTensorHandle(eager_ctx, 3.0f); + func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s)); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + } + + TF_OutputList* func_outputs = TF_NewOutputList(); + TF_OutputListSetNumOutputs(func_outputs, 2, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_ExecuteOperation(fn_op, func_args.size(), func_args.data(), func_outputs, + eager_execution_ctx, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_DeleteAbstractOp(fn_op); + for (TF_AbstractTensor* t : func_args) TF_DeleteAbstractTensor(t); + + ASSERT_EQ(2, TF_OutputListNumOutputs(func_outputs)); + float results[2]; + for (int idx = 0; idx < 2; ++idx) { + TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx); + TFE_TensorHandle* handle = TF_AbstractTensorGetEagerTensor(result, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Tensor* f_t = TFE_TensorHandleResolve(handle, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + results[idx] = *static_cast(TF_TensorData(f_t)); + TF_DeleteTensor(f_t); + } + ASSERT_EQ(results[0], 5.0); + ASSERT_EQ(results[1], 6.0); + + for (int idx = 0; idx < 2; ++idx) { + TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx); + TF_DeleteAbstractTensor(result); + } + TF_DeleteOutputList(func_outputs); + TF_DeleteExecutionContext(eager_execution_ctx); + TF_DeleteAbstractFunction(func); +} + +TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); TFE_ContextOptions* opts = TFE_NewContextOptions(); @@ -192,18 +322,15 @@ TEST(UnifedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) { ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); TFE_DeleteContextOptions(opts); - TF_AbstractFunction* func = 
TF_ExecutionContextToFunction( - ctx, nullptr, 0, nullptr, 0, nullptr, status.get()); + TF_AbstractFunction* func = TF_FinalizeFunction(ctx, nullptr, status.get()); ASSERT_EQ(nullptr, func); ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); - - TF_DeleteExecutionContext(ctx); } -TEST(UnifedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) { +TEST_P(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); - TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get()); + TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Add a placeholder to the graph. @@ -221,10 +348,10 @@ TEST(UnifedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) { TF_DeleteExecutionContext(graph_ctx); } -TEST(UnifedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) { +TEST_P(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); - TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get()); + TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Add a placeholder to the graph. @@ -242,7 +369,7 @@ TEST(UnifedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) { TF_DeleteExecutionContext(graph_ctx); } -TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) { +TEST_P(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) { // Build an Eager context. std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); @@ -272,7 +399,8 @@ TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) { ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Build a Graph context. - TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Execute eager op using graph context. @@ -288,10 +416,11 @@ TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) { TF_DeleteExecutionContext(graph_ctx); } -TEST(UnifedCAPI, TestExecutingGraphOpInEagerModeRaises) { +TEST_P(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); - TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Add a placeholder to the graph. 
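// ---------------------------------------------------------------------------
// The parameterized tests above all follow the same tracing workflow. A
// minimal sketch of that sequence, using only entry points exercised in this
// file (status checks and intermediate tensor cleanup are omitted, and names
// such as "my_fn" / "my_add" are illustrative, so treat this as a sketch
// rather than a drop-in snippet):

TF_Status* s = TF_NewStatus();
// Select the tracing engine registered by this change.
TF_SetTracingImplementation("graphdef");

// Start tracing a function and declare its parameters.
TF_ExecutionContext* fn_ctx = TF_CreateFunction("my_fn", s);
TF_AbstractTensor* arg = TF_AddFunctionParameter(fn_ctx, TF_FLOAT, s);

// Trace one "Add" node that consumes the parameter twice.
TF_AbstractOp* add = TF_NewAbstractOp(fn_ctx);
TF_AbstractOpSetOpType(add, "Add", s);
TF_AbstractOpSetOpName(add, "my_add", s);
TF_AbstractTensor* inputs[2] = {arg, arg};
TF_OutputList* outputs = TF_NewOutputList();
TF_ExecuteOperation(add, 2, inputs, outputs, fn_ctx, s);
TF_DeleteAbstractOp(add);

// Finalizing consumes the tracing context and yields a function that can be
// registered on an eager context with TF_ExecutionContextRegisterFunction and
// then executed by name like any other op.
TF_AbstractFunction* fn = TF_FinalizeFunction(fn_ctx, outputs, s);
TF_DeleteOutputList(outputs);
TF_DeleteAbstractFunction(fn);
TF_DeleteStatus(s);
// ---------------------------------------------------------------------------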
@@ -348,5 +477,7 @@ TEST(UnifedCAPI, TestExecutingGraphOpInEagerModeRaises) { TF_DeleteExecutionContext(eager_execution_ctx); } +INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI, ::testing::Values("graphdef")); + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/eager/context_interface.h b/tensorflow/c/eager/context_interface.h index be0aad31a35..d21ab45e579 100644 --- a/tensorflow/c/eager/context_interface.h +++ b/tensorflow/c/eager/context_interface.h @@ -59,6 +59,20 @@ class AbstractContextInterface { virtual AbstractTensorInterface* CreateTensor( DataType dtype, absl::Span dim_sizes) = 0; + typedef void (*MemoryReleaser)(void* data, size_t len, void* arg); + + // Create a tensor instance from the given data buffer and description. + // `memory_releaser` will be called on destruction, and it's responsible for + // cleaning up the underlying buffer. `convert_string` indicates whether it + // has to handle tstring conversion. Expected to be removed once tstring + // migration is done. + virtual AbstractTensorInterface* CreateTensor(DataType dtype, + const int64_t* dims, + int num_dims, void* data, + size_t len, bool convert_string, + MemoryReleaser memory_releaser, + void* memory_releaser_arg) = 0; + // Create a handle to wrap and manage a Tensor virtual AbstractTensorHandleInterface* CreateLocalHandle( AbstractTensorInterface* t) = 0; @@ -81,6 +95,12 @@ class AbstractContextInterface { virtual void ClearCachesAndThreadExecutors() = 0; + // Initialize the step resource container for a training step. This is used + // in current TF runtime. For tfrt, it is used by fallback op handler. + virtual void StartStep() = 0; + // Destroy the step resource container for a training step. + virtual void EndStep() = 0; + protected: virtual ~AbstractContextInterface() {} }; diff --git a/tensorflow/c/eager/parallel_device/BUILD b/tensorflow/c/eager/parallel_device/BUILD index f4dbcc6cead..3b2640e14d1 100644 --- a/tensorflow/c/eager/parallel_device/BUILD +++ b/tensorflow/c/eager/parallel_device/BUILD @@ -27,6 +27,7 @@ cc_library( name = "parallel_device", srcs = [":sources"], hdrs = [":headers"], + visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/c:c_api", "//tensorflow/c/eager:c_api", @@ -43,6 +44,7 @@ tf_cc_test( srcs = ["parallel_device_test.cc"], deps = [ ":parallel_device", + ":parallel_device_ops", "//tensorflow/c:c_api", "//tensorflow/c:c_api_experimental", "//tensorflow/c/eager:c_api", @@ -52,3 +54,19 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + +# Note: ParallelDevice-specific ops are experimental and not currently linked in +# to TensorFlow by default, just used in a few tests. +filegroup( + name = "parallel_device_ops_srcs", + srcs = ["parallel_device_ops.cc"], + visibility = ["//tensorflow/python/distribute/parallel_device:__pkg__"], +) + +cc_library( + name = "parallel_device_ops", + srcs = [":parallel_device_ops_srcs"], + visibility = ["//tensorflow:internal"], + deps = ["//tensorflow/core:framework"], + alwayslink = 1, +) diff --git a/tensorflow/c/eager/parallel_device/parallel_device.cc b/tensorflow/c/eager/parallel_device/parallel_device.cc index e6846809fcf..27c2699c4c2 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device.cc @@ -92,6 +92,10 @@ class ParallelDevice { TFE_TensorHandle* tensor, TF_Status* status) const; + // A parallel tensor with scalar integers numbering component devices. 
+ std::unique_ptr DeviceIDs(TFE_Context* context, + TF_Status* status) const; + // Takes a description of a single operation being executed on the // ParallelDevice, and in turn runs one operation per component device with // its corresponding inputs from the input ParallelTensors (or @@ -208,6 +212,46 @@ std::unique_ptr ParallelDevice::CopyToParallelDevice( status); } +std::unique_ptr ParallelDevice::DeviceIDs( + TFE_Context* context, TF_Status* status) const { + // TODO(allenl): We could cache DeviceIDs (keyed by context). + std::vector components; + components.reserve(underlying_devices_.size()); + for (int device_index = 0; device_index < underlying_devices_.size(); + ++device_index) { + int64_t* device_id = new int64_t; + *device_id = device_index; + std::unique_ptr tensor( + TF_NewTensor( + TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id, + sizeof(int64_t), + [](void* data, size_t, void* arg) { + delete reinterpret_cast(data); + }, + nullptr), + TF_DeleteTensor); + // TODO(allenl): Here and when executing regular operations, we could hold + // on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing + // device names repeatedly. + OpPtr const_op(TFE_NewOp(context, "Const", status)); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(), + status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64); + TFE_TensorHandle* device_handle; + int num_outputs = 1; + TFE_Execute(const_op.get(), &device_handle, &num_outputs, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + components.emplace_back(device_handle); + if (TF_GetCode(status) != TF_OK) return nullptr; + } + return ParallelTensor::FromTensorHandles(*this, std::move(components), + status); +} + absl::optional> ParallelDevice::Execute( TFE_Context* context, std::vector inputs, const char* operation_name, const TFE_OpAttrs* attributes, @@ -282,6 +326,13 @@ absl::optional> ParallelDevice::Execute( } result.emplace(std::move(outputs)); return result; + } else if (operation_name == std::string("DeviceID")) { + std::vector result_content; + result_content.reserve(1); + result_content.push_back(DeviceIDs(context, status)); + if (TF_GetCode(status) != TF_OK) return result; + result.emplace(std::move(result_content)); + return result; } absl::optional>> maybe_parallel_results( diff --git a/tensorflow/c/eager/parallel_device/parallel_device_ops.cc b/tensorflow/c/eager/parallel_device/parallel_device_ops.cc new file mode 100644 index 00000000000..1decffca047 --- /dev/null +++ b/tensorflow/c/eager/parallel_device/parallel_device_ops.cc @@ -0,0 +1,26 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" + +// TODO(allenl): Figure out if we need this op, and if so whether we should move +// it to core TF. Right now the eager C API does some checking of op +// registrations before calling into custom devices, but we may be able to avoid +// that. +REGISTER_OP("DeviceID") + .Output("device_id: int64") + .SetIsStateful() + .SetShapeFn(tensorflow::shape_inference::ScalarShape); diff --git a/tensorflow/c/eager/parallel_device/parallel_device_test.cc b/tensorflow/c/eager/parallel_device/parallel_device_test.cc index 9b0613b0391..fdc140407df 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_test.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device_test.cc @@ -278,14 +278,15 @@ TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first, } // Assert that `handle` is equal to `expected_value`. -void AssertScalarFloatEq(TFE_TensorHandle* handle, float expected_value) { +template +void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); std::unique_ptr value_zero( TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - ASSERT_EQ(expected_value, - *static_cast(TF_TensorData(value_zero.get()))); + EXPECT_EQ(expected_value, + *static_cast(TF_TensorData(value_zero.get()))); } template @@ -343,8 +344,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device, ExtractPerDeviceValues(context, read.get(), &components, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - AssertScalarFloatEq(components[0].get(), 20.); - AssertScalarFloatEq(components[1].get(), 20.); + ExpectScalarEq(components[0].get(), 20.); + ExpectScalarEq(components[1].get(), 20.); std::string first_device = TFE_TensorHandleBackingDeviceName(components[0].get(), status.get()); @@ -373,8 +374,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device, ExtractPerDeviceValues(context, read.get(), &components, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - AssertScalarFloatEq(components[0].get(), 23.); - AssertScalarFloatEq(components[1].get(), 18.); + ExpectScalarEq(components[0].get(), 23.); + ExpectScalarEq(components[1].get(), 18.); std::string first_device = TFE_TensorHandleBackingDeviceName(components[0].get(), status.get()); @@ -383,6 +384,32 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device, TFE_TensorHandleBackingDeviceName(components[1].get(), status.get()); ASSERT_EQ(underlying_devices[1], second_device); } + // Compute the device ID twice and verify the result + for (int i = 0; i < 2; ++i) { + std::unique_ptr op( + TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_OpSetDevice(op.get(), device_name, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + TFE_TensorHandle* result_handle; + int num_retvals = 1; + TFE_Execute(op.get(), &result_handle, &num_retvals, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + std::array components; + ExtractPerDeviceValues(context, result_handle, &components, status.get()); + 
TFE_DeleteTensorHandle(result_handle); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + ExpectScalarEq(components[0].get(), 0); + ExpectScalarEq(components[1].get(), 1); + std::string first_device = + TFE_TensorHandleBackingDeviceName(components[0].get(), status.get()); + ASSERT_EQ(underlying_devices[0], first_device); + std::string second_device = + TFE_TensorHandleBackingDeviceName(components[1].get(), status.get()); + ASSERT_EQ(underlying_devices[1], second_device); + } } TEST(PARALLEL_DEVICE, TestBasicCPU) { @@ -498,8 +525,8 @@ TEST(PARALLEL_DEVICE, TestExplicitCopies) { ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); // The value of the original tensor is replicated on each device. - AssertScalarFloatEq(components[0].get(), 3.); - AssertScalarFloatEq(components[1].get(), 3.); + ExpectScalarEq(components[0].get(), 3.); + ExpectScalarEq(components[1].get(), 3.); // Verify that the mirrors are placed on the component devices. std::string first_device = @@ -630,7 +657,7 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) { &second_components, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - AssertScalarFloatEq(second_components[1].get(), 9.); + ExpectScalarEq(second_components[1].get(), 9.); // Verify that the mirrors are placed on the component devices. std::string first_device = TFE_TensorHandleBackingDeviceName( @@ -644,8 +671,8 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) { std::array first_components; ExtractPerDeviceValues(context.get(), second_components[0].get(), &first_components, status.get()); - AssertScalarFloatEq(first_components[0].get(), 3.); - AssertScalarFloatEq(first_components[1].get(), 6.); + ExpectScalarEq(first_components[0].get(), 3.); + ExpectScalarEq(first_components[1].get(), 6.); first_device = TFE_TensorHandleBackingDeviceName(first_components[0].get(), status.get()); @@ -806,8 +833,8 @@ TEST(PARALLEL_DEVICE, TestCollective) { ExtractPerDeviceValues(context.get(), reduced.get(), &result_components, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - AssertScalarFloatEq(result_components[0].get(), 3.); - AssertScalarFloatEq(result_components[1].get(), 3.); + ExpectScalarEq(result_components[0].get(), 3.); + ExpectScalarEq(result_components[1].get(), 3.); } void RegisterCollectiveMulFunction(TFE_Context* context, @@ -909,8 +936,8 @@ TEST(PARALLEL_DEVICE, TestFunction) { ExtractPerDeviceValues(context.get(), reduced.get(), &result_components, status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - AssertScalarFloatEq(result_components[0].get(), 7. * 9.); - AssertScalarFloatEq(result_components[1].get(), 7. * 9.); + ExpectScalarEq(result_components[0].get(), 7. * 9.); + ExpectScalarEq(result_components[1].get(), 7. * 9.); std::string first_device = TFE_TensorHandleBackingDeviceName( result_components[0].get(), status.get()); diff --git a/tensorflow/c/eager/tfe_op_attrs_internal.h b/tensorflow/c/eager/tfe_op_attrs_internal.h index 935d7d520e5..0287502dea6 100644 --- a/tensorflow/c/eager/tfe_op_attrs_internal.h +++ b/tensorflow/c/eager/tfe_op_attrs_internal.h @@ -15,33 +15,21 @@ limitations under the License. 
#ifndef TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_ #define TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_ -#include -#include -#include -#include -#include -#include -#include - +#include "tensorflow/c/conversion_macros.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/framework/attr_value.pb.h" // An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways // that sometimes do not require serialization. +typedef struct TFE_OpAttrs TFE_OpAttrs; + typedef struct TFE_Context TFE_Context; typedef struct TFE_Op TFE_Op; -struct TFE_OpAttrs { - explicit TFE_OpAttrs() : attributes(nullptr) {} - - explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value) - : attributes(value) {} - - const tensorflow::AttrBuilder* attributes; -}; - namespace tensorflow { +DEFINE_CONVERSION_FUNCTIONS(tensorflow::AttrBuilder, TFE_OpAttrs); + // Set an AttrValue on the op. Doesn't handle the list types. void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value, diff --git a/tensorflow/c/experimental/filesystem/modular_filesystem_test.cc b/tensorflow/c/experimental/filesystem/modular_filesystem_test.cc index 53e247cd038..8ee47da01dd 100644 --- a/tensorflow/c/experimental/filesystem/modular_filesystem_test.cc +++ b/tensorflow/c/experimental/filesystem/modular_filesystem_test.cc @@ -85,17 +85,36 @@ class ModularFileSystemTest : public ::testing::TestWithParam { const std::string test_name = tensorflow::str_util::StringReplace( ::testing::UnitTest::GetInstance()->current_test_info()->name(), "/", "_", /*replace_all=*/true); - root_dir_ = tensorflow::io::JoinPath( - ::testing::TempDir(), - tensorflow::strings::StrCat("tf_fs_", rng_val_, "_", test_name)); + if (!cloud_path_.empty()) { + // We have to join path for non-local filesystem manually to make sure + // that this test will run on Windows since `tensorflow::io::JoinPath` + // behaves differently on Windows. `tmp_dir` should be something like + // `path/to/tmp/dir/`. After joining path, we will have + // /path/to/tmp/dir/tf_fs_rng_name/` + root_dir_ = tensorflow::strings::StrCat( + "/", tmp_dir_, + tensorflow::strings::StrCat("tf_fs_", rng_val_, "_", test_name), "/"); + } else { + root_dir_ = tensorflow::io::JoinPath( + tmp_dir_, + tensorflow::strings::StrCat("tf_fs_", rng_val_, "_", test_name)); + } + if (!GetParam().empty()) { + root_dir_ = tensorflow::strings::StrCat(GetParam(), "://", cloud_path_, + root_dir_); + } env_ = Env::Default(); } void SetUp() override { - if (mkdir(root_dir_.c_str(), 0755) != 0) { - int error_code = errno; - GTEST_SKIP() << "Cannot create working directory: " - << tensorflow::IOError(root_dir_, error_code); + FileSystem* fs = nullptr; + Status s = env_->GetFileSystemForFile(root_dir_, &fs); + if (fs == nullptr || !s.ok()) + GTEST_SKIP() << "No filesystem registered: " << s; + + s = fs->CreateDir(root_dir_); + if (!s.ok()) { + GTEST_SKIP() << "Cannot create working directory: " << s; } } @@ -115,9 +134,10 @@ class ModularFileSystemTest : public ::testing::TestWithParam { std::string GetURIForPath(StringPiece path) { const std::string translated_name = tensorflow::io::JoinPath(root_dir_, path); - if (GetParam().empty()) return translated_name; - - return tensorflow::strings::StrCat(GetParam(), "://", translated_name); + // We have already checked `GetParam().empty()` in + // `ModularFileSystemTest()`. root_dir_ should contain `GetParam() + "://"` + // if it isn't empty. 
+ return translated_name; } // Converts absolute paths to paths relative to root_dir_. @@ -133,15 +153,28 @@ class ModularFileSystemTest : public ::testing::TestWithParam { rng_val_ = distribution(gen); } + static void SetCloudPath(const std::string& cloud_path) { + cloud_path_ = cloud_path; + if (cloud_path_.back() == '/') cloud_path_.pop_back(); + } + + static void SetTmpDir(const std::string& tmp_dir) { + tmp_dir_ = tmp_dir.empty() ? ::testing::TempDir() : tmp_dir; + } + protected: Env* env_; private: std::string root_dir_; static int rng_val_; + static std::string cloud_path_; + static std::string tmp_dir_; }; int ModularFileSystemTest::rng_val_; +std::string ModularFileSystemTest::cloud_path_; +std::string ModularFileSystemTest::tmp_dir_; // As some of the implementations might be missing, the tests should still pass // if the returned `Status` signals the unimplemented state. @@ -1729,6 +1762,20 @@ static bool GetURIScheme(const std::string& scheme) { return true; } +// This function is used for cloud filesystem +// `S3` and `GCS` require the `root_dir_` to have bucket name +// `HDFS` requires the `root_dir` to have namenode +// `root_dir_ = scheme + "://" cloud_path_ + root_dir_` +static bool SetCloudPath(const std::string& cloud_path_) { + ModularFileSystemTest::SetCloudPath(cloud_path_); + return true; +} + +static bool SetTmpDir(const std::string& tmp_dir_) { + ModularFileSystemTest::SetTmpDir(tmp_dir_); + return true; +} + } // namespace } // namespace tensorflow @@ -1741,7 +1788,12 @@ GTEST_API_ int main(int argc, char** argv) { tensorflow::Flag("dso", tensorflow::LoadDSO, "", "Path to shared object to load"), tensorflow::Flag("scheme", tensorflow::GetURIScheme, "", - "URI scheme to test")}; + "URI scheme to test"), + tensorflow::Flag("cloud_path", tensorflow::SetCloudPath, "", + "Path for cloud filesystem (namenode for hdfs, " + "bucketname for s3/gcs)"), + tensorflow::Flag("tmp_dir", tensorflow::SetTmpDir, "", + "Temporary directory to store test data.")}; if (!tensorflow::Flags::Parse(&argc, argv, flag_list)) { std::cout << tensorflow::Flags::Usage(argv[0], flag_list); return -1; diff --git a/tensorflow/c/experimental/saved_model/internal/BUILD b/tensorflow/c/experimental/saved_model/internal/BUILD index 7a694f4f803..5c51e26f925 100644 --- a/tensorflow/c/experimental/saved_model/internal/BUILD +++ b/tensorflow/c/experimental/saved_model/internal/BUILD @@ -31,9 +31,6 @@ cc_library( "//tensorflow/c/experimental/saved_model/public:concrete_function.h", ], copts = tf_copts(), - # TODO(bmzhao): Remove this as we refactor C API to granular targets, - # so that we can depend on c/eager/c_api_unified_experimental.h. 
- features = ["-layering_check"], visibility = [ "//tensorflow/c/experimental/saved_model/public:__pkg__", ], @@ -41,6 +38,8 @@ cc_library( ":concrete_function_type", ":function_metadata", ":function_metadata_type", + ":tensorhandle_list", + ":tensorhandle_list_type", "//tensorflow/c:c_api_macros", "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api_internal", @@ -160,6 +159,38 @@ cc_library( ], ) +cc_library( + name = "tensorhandle_list", + srcs = [ + "tensorhandle_list.cc", + ], + hdrs = [ + "//tensorflow/c/experimental/saved_model/public:tensorhandle_list.h", + ], + copts = tf_copts(), + visibility = [ + "//tensorflow/c/experimental/saved_model/public:__pkg__", + ], + deps = [ + ":tensorhandle_list_type", + "//tensorflow/c:c_api_macros", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:tensor_handle_interface", + "//tensorflow/c/eager:tfe_tensorhandle_internal", + ], +) + +cc_library( + name = "tensorhandle_list_type", + hdrs = [ + "tensorhandle_list_type.h", + ], + deps = [ + "//tensorflow/c:conversion_macros", + "//tensorflow/c/eager:tensor_handle_interface", + ], +) + tf_cc_test( name = "saved_model_api_test", size = "small", diff --git a/tensorflow/c/experimental/saved_model/internal/concrete_function.cc b/tensorflow/c/experimental/saved_model/internal/concrete_function.cc index 4884f9e2e97..dd54416ddf9 100644 --- a/tensorflow/c/experimental/saved_model/internal/concrete_function.cc +++ b/tensorflow/c/experimental/saved_model/internal/concrete_function.cc @@ -15,12 +15,12 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/public/concrete_function.h" -#include "tensorflow/c/eager/c_api_unified_experimental.h" #include "tensorflow/c/eager/tfe_op_internal.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h" #include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h" #include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h" +#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h" extern "C" { @@ -29,10 +29,9 @@ TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(TF_ConcreteFunction* func) { &tensorflow::unwrap(func)->GetFunctionMetadata())); } -TF_OutputList* TF_ConcreteFunctionGetCaptures(TF_ConcreteFunction* func) { - // TODO(bmzhao): Refactor TF_OutputList struct definition into a separate - // internal header, and implement this function. 
- return nullptr; +const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures( + TF_ConcreteFunction* func) { + return tensorflow::wrap(&tensorflow::unwrap(func)->GetCaptures()); } TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func) { diff --git a/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc b/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc index cce1b27d9ad..629610dbe29 100644 --- a/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc +++ b/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc @@ -66,7 +66,7 @@ TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx, void TF_DeleteSavedModel(TF_SavedModel* model) { delete model; } TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model, - char* function_path, + const char* function_path, TF_Status* status) { tensorflow::ConcreteFunction* result = nullptr; tensorflow::Status get_function_status = @@ -79,7 +79,7 @@ TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model, } TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction( - TF_SavedModel* model, char* signature_def_key, TF_Status* status) { + TF_SavedModel* model, const char* signature_def_key, TF_Status* status) { tensorflow::ConcreteFunction* result = nullptr; tensorflow::Status get_function_status = model->saved_model->GetSignatureDefFunction(signature_def_key, &result); diff --git a/tensorflow/c/experimental/saved_model/internal/tensorhandle_list.cc b/tensorflow/c/experimental/saved_model/internal/tensorhandle_list.cc new file mode 100644 index 00000000000..7d018658101 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/tensorhandle_list.cc @@ -0,0 +1,36 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h" + +#include + +#include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" +#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h" + +extern "C" { + +size_t TF_TensorHandleListSize(const TF_TensorHandleList* list) { + return tensorflow::unwrap(list)->size(); +} + +TFE_TensorHandle* TF_TensorHandleListGet(const TF_TensorHandleList* list, + int i) { + return tensorflow::wrap((*tensorflow::unwrap(list))[i]); +} + + +} // end extern "C" diff --git a/tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h b/tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h new file mode 100644 index 00000000000..8cbec2806a8 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h @@ -0,0 +1,37 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_ + +#include + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/eager/tensor_handle_interface.h" + +// Internal structures used by the SavedModel C API. These are likely to +// change and should not be depended on. + +typedef struct TF_TensorHandleList TF_TensorHandleList; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS( + std::vector, + TF_TensorHandleList) + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_ diff --git a/tensorflow/c/experimental/saved_model/public/BUILD b/tensorflow/c/experimental/saved_model/public/BUILD index af65e05e7f6..0cfa0a2c005 100644 --- a/tensorflow/c/experimental/saved_model/public/BUILD +++ b/tensorflow/c/experimental/saved_model/public/BUILD @@ -24,6 +24,7 @@ exports_files( "concrete_function_list.h", "function_metadata.h", "saved_model_api.h", + "tensorhandle_list.h", ], visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"], ) @@ -39,6 +40,7 @@ cc_library( ":concrete_function_list", ":function_metadata", ":saved_model_api", + ":tensorhandle_list", ], ) @@ -61,3 +63,8 @@ alias( name = "saved_model_api", actual = "//tensorflow/c/experimental/saved_model/internal:saved_model_api", ) + +alias( + name = "tensorhandle_list", + actual = "//tensorflow/c/experimental/saved_model/internal:tensorhandle_list", +) diff --git a/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h b/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h index 30f533f140a..aae95a5477c 100644 --- a/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h +++ b/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h" #include "tensorflow/c/experimental/saved_model/public/function_metadata.h" #include "tensorflow/c/experimental/saved_model/public/saved_model_api.h" +#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h" // IWYU pragma: end_exports #endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_ diff --git a/tensorflow/c/experimental/saved_model/public/concrete_function.h b/tensorflow/c/experimental/saved_model/public/concrete_function.h index 351d8daed8e..2a87214270c 100644 --- a/tensorflow/c/experimental/saved_model/public/concrete_function.h +++ b/tensorflow/c/experimental/saved_model/public/concrete_function.h @@ -17,9 +17,9 @@ limitations under the License. 
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_ #include "tensorflow/c/c_api_macros.h" -#include "tensorflow/c/eager/c_api_internal.h" -#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/experimental/saved_model/public/function_metadata.h" +#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h" #ifdef __cplusplus extern "C" { @@ -36,7 +36,7 @@ TF_CAPI_EXPORT extern TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata( TF_ConcreteFunction* func); // Returns a list of TensorHandles implicitly captured by this function. -TF_CAPI_EXPORT extern TF_OutputList* TF_ConcreteFunctionGetCaptures( +TF_CAPI_EXPORT extern const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures( TF_ConcreteFunction* func); // Returns a TFE_Op suitable for executing this function. diff --git a/tensorflow/c/experimental/saved_model/public/concrete_function_list.h b/tensorflow/c/experimental/saved_model/public/concrete_function_list.h index 7add847259c..e35546751f1 100644 --- a/tensorflow/c/experimental/saved_model/public/concrete_function_list.h +++ b/tensorflow/c/experimental/saved_model/public/concrete_function_list.h @@ -21,19 +21,27 @@ limitations under the License. #include "tensorflow/c/c_api_macros.h" #include "tensorflow/c/experimental/saved_model/public/concrete_function.h" +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + // An opaque type that is acts like a list of TF_ConcreteFunction pointers. typedef struct TF_ConcreteFunctionList TF_ConcreteFunctionList; // Returns the size of `list`. -TF_CAPI_EXPORT size_t -TF_ConcreteFunctionListSize(TF_ConcreteFunctionList* list); +TF_CAPI_EXPORT extern size_t TF_ConcreteFunctionListSize( + TF_ConcreteFunctionList* list); // Returns the `i`th TF_ConcreteFunction in the list. -TF_CAPI_EXPORT TF_ConcreteFunction* TF_ConcreteFunctionListGet( +TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_ConcreteFunctionListGet( TF_ConcreteFunctionList* list, int i); // Deletes `list`. -TF_CAPI_EXPORT void TF_DeleteConcreteFunctionList( +TF_CAPI_EXPORT extern void TF_DeleteConcreteFunctionList( TF_ConcreteFunctionList* list); +#ifdef __cplusplus +} // end extern "C" +#endif // __cplusplus + #endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_ diff --git a/tensorflow/c/experimental/saved_model/public/saved_model_api.h b/tensorflow/c/experimental/saved_model/public/saved_model_api.h index ad381937e3c..875167bec63 100644 --- a/tensorflow/c/experimental/saved_model/public/saved_model_api.h +++ b/tensorflow/c/experimental/saved_model/public/saved_model_api.h @@ -80,7 +80,7 @@ TF_CAPI_EXPORT extern void TF_DeleteSavedModel(TF_SavedModel* model); // "conceptually" bound to `model`. Once `model` is deleted, all // `TF_ConcreteFunctions` retrieved from it are invalid, and have been deleted. TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelConcreteFunction( - TF_SavedModel* model, char* function_path, TF_Status* status); + TF_SavedModel* model, const char* function_path, TF_Status* status); // Retrieve a function from the TF SavedModel via a SignatureDef key. // @@ -94,7 +94,7 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelConcreteFunction( // TF_ConcreteFunction instance. Once `model` is deleted, all // `TF_ConcreteFunctions` retrieved from it are invalid, and have been deleted. 
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction( - TF_SavedModel* model, char* signature_def_key, TF_Status* status); + TF_SavedModel* model, const char* signature_def_key, TF_Status* status); // Returns a list of all ConcreteFunctions stored in this SavedModel. // The lifetime of the returned list is bound to `model`. diff --git a/tensorflow/c/experimental/saved_model/public/tensorhandle_list.h b/tensorflow/c/experimental/saved_model/public/tensorhandle_list.h new file mode 100644 index 00000000000..a1e88db3474 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/public/tensorhandle_list.h @@ -0,0 +1,43 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_ + +#include + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/eager/c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// An opaque type that is acts like a list of TF_ConcreteFunction pointers. +typedef struct TF_TensorHandleList TF_TensorHandleList; + +// Returns the size of `list`. +TF_CAPI_EXPORT extern size_t TF_TensorHandleListSize( + const TF_TensorHandleList* list); + +// Returns the `i`th TFE_TensorHandle in the list. 
+TF_CAPI_EXPORT extern TFE_TensorHandle* TF_TensorHandleListGet( + const TF_TensorHandleList* list, int i); + +#ifdef __cplusplus +} // end extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_ diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index e8cb40f153b..e1fad8e697a 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -178,7 +178,7 @@ cc_library_with_android_deps( name = "ops", srcs = ["framework/ops.cc"], hdrs = ["framework/ops.h"], - android_deps = ["//tensorflow/core:android_tensorflow_lib"], + android_deps = ["//tensorflow/core:portable_tensorflow_lib"], deps = [ "//tensorflow/core:core_cpu", "//tensorflow/core:framework", @@ -197,7 +197,7 @@ cc_library_with_android_deps( "framework/scope_internal.h", ], hdrs = ["framework/scope.h"], - android_deps = ["//tensorflow/core:android_tensorflow_lib"], + android_deps = ["//tensorflow/core:portable_tensorflow_lib"], common_deps = [ ":ops", ], @@ -237,7 +237,7 @@ cc_library_with_android_deps( name = "client_session", srcs = ["client/client_session.cc"], hdrs = ["client/client_session.h"], - android_deps = ["//tensorflow/core:android_tensorflow_lib"], + android_deps = ["//tensorflow/core:portable_tensorflow_lib"], common_deps = [ ":ops", ":scope", @@ -275,7 +275,7 @@ cc_library_with_android_deps( srcs = ["ops/const_op.cc"], hdrs = ["ops/const_op.h"], android_deps = [ - "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core:portable_tensorflow_lib", ], common_deps = [ ":ops", @@ -304,7 +304,7 @@ cc_library_with_android_deps( srcs = ["ops/while_loop.cc"], hdrs = ["ops/while_loop.h"], android_deps = [ - "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core:portable_tensorflow_lib", ], common_deps = [ ":cc_ops", diff --git a/tensorflow/cc/experimental/base/public/BUILD b/tensorflow/cc/experimental/base/public/BUILD new file mode 100644 index 00000000000..045d4e6cd97 --- /dev/null +++ b/tensorflow/cc/experimental/base/public/BUILD @@ -0,0 +1,78 @@ +# Experimental C++ APIs for TensorFlow. +# New TF C++ APIs under the tensorflow::cc namespace aim to guarantee ABI stability. +# Users are expected to compile against public c++ headers, and link against +# libtensorflow (https://www.tensorflow.org/install/lang_c). +# We aim to achieve ABI stability in new C++ APIs by only using types +# on the API surface that: +# 1. Have a header-only implementation +# 2. Are std:: types +# 3. 
Wrap an opaque C type + +package( + # This is intentionally public + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "runtime", + hdrs = [ + "runtime.h", + ], + deps = [ + ":status", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_experimental", + ], +) + +cc_library( + name = "runtime_builder", + hdrs = [ + "runtime_builder.h", + ], + deps = [ + ":runtime", + ":status", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_experimental", + ], +) + +cc_library( + name = "status", + hdrs = [ + "status.h", + ], + deps = [ + "//tensorflow/c:tf_status", + ], +) + +cc_library( + name = "tensor", + hdrs = [ + "tensor.h", + ], + deps = [ + ":status", + "//tensorflow/c:tf_datatype", + "//tensorflow/c:tf_tensor", + ], +) + +cc_library( + name = "tensorhandle", + hdrs = [ + "tensorhandle.h", + ], + deps = [ + ":runtime", + ":status", + ":tensor", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_experimental", + ], +) diff --git a/tensorflow/cc/experimental/base/public/runtime.h b/tensorflow/cc/experimental/base/public/runtime.h new file mode 100644 index 00000000000..711a38c233a --- /dev/null +++ b/tensorflow/cc/experimental/base/public/runtime.h @@ -0,0 +1,71 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_ +#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_ + +#include + +#include "tensorflow/c/eager/c_api_experimental.h" + +namespace tensorflow { +namespace experimental { +namespace cc { + +// Runtime represents an opaque instance of a Tensorflow runtime, with its own +// resources, threadpools, etc. Clients are expected to construct a Runtime +// object through tensorflow::cc::RuntimeBuilder::Build, after setting any +// relevant configuration options. Many Tensorflow functions take a reference to +// the runtime as an argument (eg: tensorflow::cc::SavedModelAPI::Load), and +// may have different implementations depending on the runtime. For many of +// these Runtime-attached objects (such as tensorflow::cc::TensorHandle), the +// Runtime must outlive these objects. +class Runtime { + public: + // Runtime is movable, but not copyable. + Runtime(Runtime&&) = default; + Runtime& operator=(Runtime&&) = default; + + private: + friend class RuntimeBuilder; + friend class SavedModelAPI; + friend class TensorHandle; + + // Wraps a TFE_Context. Takes ownership of ctx. + explicit Runtime(TFE_Context* ctx) : ctx_(ctx) {} + + // Deletes the currently wrapped TFE_Context, swaps it with ctx, + // and takes ownership of ctx. + void Reset(TFE_Context* ctx) { ctx_.reset(ctx); } + + // Returns the TFE_Context that this object wraps. This object + // retains ownership of the pointer. 
+ TFE_Context* GetTFEContext() const { return ctx_.get(); } + + // Runtime is not copyable + Runtime(const Runtime&) = delete; + Runtime& operator=(const Runtime&) = delete; + + struct TFEContextDeleter { + void operator()(TFE_Context* p) const { TFE_DeleteContext(p); } + }; + std::unique_ptr ctx_; +}; + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_ diff --git a/tensorflow/cc/experimental/base/public/runtime_builder.h b/tensorflow/cc/experimental/base/public/runtime_builder.h new file mode 100644 index 00000000000..737e06cb2c6 --- /dev/null +++ b/tensorflow/cc/experimental/base/public/runtime_builder.h @@ -0,0 +1,86 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_ +#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_ + +#include + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/cc/experimental/base/public/runtime.h" +#include "tensorflow/cc/experimental/base/public/status.h" + +namespace tensorflow { +namespace experimental { +namespace cc { + +// RuntimeBuilder is a builder used to construct a tensorflow::cc::Runtime. +// Use this to set configuration options, like threadpool size, etc. +class RuntimeBuilder { + public: + RuntimeBuilder() : options_(TFE_NewContextOptions()) {} + + // If `use_tfrt` is true, we will use the new Tensorflow Runtime + // (https://blog.tensorflow.org/2020/04/tfrt-new-tensorflow-runtime.html) as + // our runtime implementation. + RuntimeBuilder& SetUseTFRT(bool use_tfrt); + + // Build a Tensorflow Runtime. + // + // Params: + // status - Set to OK on success and an appropriate error on failure. + // Returns: + // If status is not OK, returns nullptr. Otherwise, returns a + // unique_ptr. + std::unique_ptr Build(Status* status); + + // RuntimeBuilder is movable, but not copyable. 
+ RuntimeBuilder(RuntimeBuilder&&) = default; + RuntimeBuilder& operator=(RuntimeBuilder&&) = default; + + private: + // RuntimeBuilder is not copyable + RuntimeBuilder(const RuntimeBuilder&) = delete; + RuntimeBuilder& operator=(const RuntimeBuilder&) = delete; + + struct TFEContextOptionsDeleter { + void operator()(TFE_ContextOptions* p) const { + TFE_DeleteContextOptions(p); + } + }; + std::unique_ptr options_; +}; + +inline RuntimeBuilder& RuntimeBuilder::SetUseTFRT(bool use_tfrt) { + TFE_ContextOptionsSetTfrt(options_.get(), use_tfrt); + return *this; +} + +inline std::unique_ptr RuntimeBuilder::Build(Status* status) { + TFE_Context* result = TFE_NewContext(options_.get(), status->GetTFStatus()); + if (!status->ok()) { + return nullptr; + } + // We can't use std::make_unique here because of its interaction with a + // private constructor: https://abseil.io/tips/134 + return std::unique_ptr(new Runtime(result)); +} + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_ diff --git a/tensorflow/cc/experimental/base/public/status.h b/tensorflow/cc/experimental/base/public/status.h new file mode 100644 index 00000000000..98c8cf6ced2 --- /dev/null +++ b/tensorflow/cc/experimental/base/public/status.h @@ -0,0 +1,96 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_ +#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_ + +#include +#include + +#include "tensorflow/c/tf_status.h" + +namespace tensorflow { +namespace experimental { +namespace cc { + +// Status is a wrapper around an error code and an optional error message. +// The set of error codes are defined here: +// https://github.com/tensorflow/tensorflow/blob/08931c1e3e9eb2e26230502d678408e66730826c/tensorflow/c/tf_status.h#L39-L60 +// Many Tensorflow APIs return a Status, or take a Status as an out parameter. +// Clients should check for status.ok() after calling these APIs, and either +// handle or propagate the error appropriately. +// TODO(bmzhao): Add a detailed code example before moving out of experimental. +class Status { + public: + // Create a success status + Status() : status_(TF_NewStatus()) {} + + // Return the status code + TF_Code code() const; + + // Returns the error message in Status. + std::string message() const; + + // Returns the error message in Status. + bool ok() const; + + // Record in Status. Any previous information is lost. + // A common use is to clear a status: SetStatus(TF_OK, ""); + void SetStatus(TF_Code code, const std::string& msg); + + // Status is movable, but not copyable. + Status(Status&&) = default; + Status& operator=(Status&&) = default; + + private: + friend class RuntimeBuilder; + friend class Runtime; + friend class SavedModelAPI; + friend class TensorHandle; + + // Wraps a TF_Status*, and takes ownership of it. 
+ explicit Status(TF_Status* status) : status_(status) {} + + // Status is not copyable + Status(const Status&) = delete; + Status& operator=(const Status&) = delete; + + // Returns the TF_Status that this object wraps. This object + // retains ownership of the pointer. + TF_Status* GetTFStatus() const { return status_.get(); } + + struct TFStatusDeleter { + void operator()(TF_Status* p) const { TF_DeleteStatus(p); } + }; + std::unique_ptr status_; +}; + +inline TF_Code Status::code() const { return TF_GetCode(status_.get()); } + +inline std::string Status::message() const { + return std::string(TF_Message(status_.get())); +} + +inline bool Status::ok() const { return code() == TF_OK; } + +inline void Status::SetStatus(TF_Code code, const std::string& msg) { + TF_SetStatus(status_.get(), code, msg.c_str()); +} + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_ diff --git a/tensorflow/cc/experimental/base/public/tensor.h b/tensorflow/cc/experimental/base/public/tensor.h new file mode 100644 index 00000000000..fc447262ce1 --- /dev/null +++ b/tensorflow/cc/experimental/base/public/tensor.h @@ -0,0 +1,175 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_ +#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_ + +#include +#include + +#include +#include +#include + +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/cc/experimental/base/public/status.h" + +namespace tensorflow { +namespace experimental { +namespace cc { + +// Tensor represents an n-dimensional array of values. +class Tensor { + public: + using DeleterCallback = std::function; + + // Constructs a Tensor from user provided buffer. + // + // Params: + // dtype - The dtype of the tensor's data. + // shape - A shape vector, where each element corresponds to the size of + // the tensor's corresponding dimension. + // data - Pointer to a buffer of memory to construct a Tensor out of. + // len - The length (in bytes) of `data` + // deleter - A std::function to be called when the Tensor no longer needs the + // memory in `data`. This can be used to free `data`, or + // perhaps decrement a refcount associated with `data`, etc. + // status - Set to OK on success and an error on failure. + // Returns: + // If an error occurred, status->ok() will be false, and the returned + // Tensor must not be used. + // TODO(bmzhao): Add Runtime as an argument to this function so we can swap to + // a TFRT backed tensor. + // TODO(bmzhao): Add benchmarks on overhead for this function; we can + // consider using int64_t* + length rather than vector. 
+ static Tensor FromBuffer(TF_DataType dtype, const std::vector<int64_t>& shape, + void* data, size_t len, DeleterCallback deleter, + Status* status); + + // TODO(bmzhao): In the case we construct a tensor from non-owned memory, + // we should offer a way to deep copy the tensor into a new tensor, which + // owns the underlying memory. This could be a .deepcopy()/clone() method. + + // TODO(bmzhao): In the future, we want to relax the non-copyability + // constraint. To do so, we can add a C API function that acts like + // CopyFrom: + // https://github.com/tensorflow/tensorflow/blob/08931c1e3e9eb2e26230502d678408e66730826c/tensorflow/core/framework/tensor.h#L301-L311 + + // Tensor is movable, but not copyable + Tensor(Tensor&&) = default; + Tensor& operator=(Tensor&&) = default; + + // Returns the number of dimensions in the tensor. Can be -1, which represents + // unknown rank. + int dims() const; + + // Returns the number of elements in dimension `d`. + // REQUIRES: `0 <= d < dims()` + int64_t dim_size(int d) const; + + // Returns a pointer to the underlying data buffer. + void* data() const; + + // Returns the data type of the tensor. + TF_DataType dtype() const; + + // Returns the number of elements in the tensor. For a tensor with a partially + // defined shape, -1 means not fully defined. + int64_t num_elements() const; + + // Returns the size of the underlying data in bytes. + size_t num_bytes() const; + + private: + friend class TensorHandle; + friend class Runtime; + + // Wraps a TF_Tensor. Takes ownership of tensor. + explicit Tensor(TF_Tensor* tensor) : tensor_(tensor) {} + + // Tensor is not copyable + Tensor(const Tensor&) = delete; + Tensor& operator=(const Tensor&) = delete; + + // Returns the underlying TF_Tensor that this object wraps. + // This object retains ownership of the pointer. + TF_Tensor* GetTFTensor() const { return tensor_.get(); } + + struct DeleterStruct { + std::function<void(void*, size_t)> deleter; + }; + + static void DeleterFunction(void* memory, size_t len, void* deleter_struct) { + DeleterStruct* deleter = reinterpret_cast<DeleterStruct*>(deleter_struct); + deleter->deleter(memory, len); + delete deleter; + } + + struct TFTensorDeleter { + void operator()(TF_Tensor* p) const { TF_DeleteTensor(p); } + }; + std::unique_ptr<TF_Tensor, TFTensorDeleter> tensor_; +}; + +inline void* Tensor::data() const { return TF_TensorData(tensor_.get()); } + +inline int Tensor::dims() const { return TF_NumDims(tensor_.get()); } + +inline int64_t Tensor::dim_size(int d) const { + return TF_Dim(tensor_.get(), d); +} + +inline TF_DataType Tensor::dtype() const { + return TF_TensorType(tensor_.get()); +} + +inline int64_t Tensor::num_elements() const { + return TF_TensorElementCount(tensor_.get()); +} + +inline size_t Tensor::num_bytes() const { + return TF_TensorByteSize(tensor_.get()); +} + +inline Tensor Tensor::FromBuffer(TF_DataType dtype, + const std::vector<int64_t>& shape, void* data, + size_t len, DeleterCallback deleter, + Status* status) { + // Credit to apassos@ for this technique: + // Despite the fact that our API takes a std::function deleter, we are able + // to maintain ABI stability because: + // 1. Only a function pointer is sent across the C API (&DeleterFunction) + // 2. DeleterFunction is defined in the same build artifact that constructed + // the std::function (so there isn't confusion about std::function ABI). + // Note that 2. is satisfied by the fact that this is a header-only API, where + // the function implementations are inline.
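The ABI argument above is the usual trampoline pattern: box the std::function on the heap, pass only a plain function pointer plus a void* context across the C boundary, and delete the box once the callback has fired. A stripped-down sketch of the same idea, independent of the TF_Tensor API (the C-style callback type here is hypothetical, not a real TensorFlow symbol):

    #include <cstddef>
    #include <functional>

    // Hypothetical C-style callback: the C side sees only a function pointer
    // and an opaque context argument.
    using CDeallocator = void (*)(void* data, size_t len, void* arg);

    struct DeleterBox {
      std::function<void(void*, size_t)> fn;
    };

    // Trampoline with a C-compatible signature. It lives in the same build
    // artifact that created the std::function, so std::function's ABI never
    // crosses the C boundary.
    void Trampoline(void* data, size_t len, void* arg) {
      DeleterBox* box = static_cast<DeleterBox*>(arg);
      box->fn(data, len);
      delete box;  // One-shot: the box frees itself after the callback runs.
    }

    // The implementation that follows does the equivalent of:
    //   c_api_new_tensor(..., &Trampoline, new DeleterBox{std::move(deleter)});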
+ + DeleterStruct* deleter_struct = new DeleterStruct{deleter}; + TF_Tensor* tensor = TF_NewTensor(dtype, shape.data(), shape.size(), data, len, + &DeleterFunction, deleter_struct); + if (tensor == nullptr) { + status->SetStatus(TF_INVALID_ARGUMENT, + "Failed to create tensor for input buffer"); + return Tensor(nullptr); + } + return Tensor(tensor); +} + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_ diff --git a/tensorflow/cc/experimental/base/public/tensorhandle.h b/tensorflow/cc/experimental/base/public/tensorhandle.h new file mode 100644 index 00000000000..99453ee7ea8 --- /dev/null +++ b/tensorflow/cc/experimental/base/public/tensorhandle.h @@ -0,0 +1,98 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_ +#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_ + +#include <memory> +#include <vector> + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/cc/experimental/base/public/runtime.h" +#include "tensorflow/cc/experimental/base/public/status.h" +#include "tensorflow/cc/experimental/base/public/tensor.h" + +namespace tensorflow { +namespace experimental { +namespace cc { + +// An opaque representation of a tensor computed/managed by the TensorFlow +// runtime (tensorflow::cc::Runtime). Unlike a Tensor, a TensorHandle may refer +// to tensors placed in memory of different devices or remote address spaces. +// Note that tensorflow::cc::Runtime MUST outlive all TensorHandles created +// from it. +class TensorHandle { + public: + // Unwraps a Tensor from the given TensorHandle. If an error occurred, + // status->ok() will be false, and the returned Tensor must not be used. + Tensor Resolve(Status* status); + + // Constructs a TensorHandle from a Tensor. If an error occurred, + // status->ok() will be false, and the returned TensorHandle must not be used. + static TensorHandle FromTensor(const Tensor& tensor, const Runtime& runtime, + Status* status); + + // TensorHandle is movable, but not copyable + TensorHandle(TensorHandle&&) = default; + TensorHandle& operator=(TensorHandle&&) = default; + + private: + // Wraps a TFE_TensorHandle. Takes ownership of handle. + explicit TensorHandle(TFE_TensorHandle* handle) : handle_(handle) {} + + // TensorHandle is not copyable + TensorHandle(const TensorHandle&) = delete; + TensorHandle& operator=(const TensorHandle&) = delete; + + // Returns the underlying TFE_TensorHandle that this object wraps. + // This object retains ownership of the pointer. + TFE_TensorHandle* GetTFETensorHandle() const { return handle_.get(); } + + // Deletes the currently wrapped TFE_TensorHandle, swaps in `handle`, + // and takes ownership of it.
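Putting the two classes together, a round trip through the runtime looks roughly as follows; this is a sketch assuming a Runtime built as shown earlier, and the Runtime must outlive the handle, per the class comment above:

    #include "tensorflow/cc/experimental/base/public/runtime.h"
    #include "tensorflow/cc/experimental/base/public/status.h"
    #include "tensorflow/cc/experimental/base/public/tensor.h"
    #include "tensorflow/cc/experimental/base/public/tensorhandle.h"

    using tensorflow::experimental::cc::Runtime;
    using tensorflow::experimental::cc::Status;
    using tensorflow::experimental::cc::Tensor;
    using tensorflow::experimental::cc::TensorHandle;

    // Hands `tensor` to the runtime, then materializes it back into host memory.
    bool RoundTrips(const Tensor& tensor, const Runtime& runtime) {
      Status status;
      TensorHandle handle = TensorHandle::FromTensor(tensor, runtime, &status);
      if (!status.ok()) return false;
      Tensor resolved = handle.Resolve(&status);
      return status.ok() && resolved.num_elements() == tensor.num_elements();
    }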
+ void Reset(TFE_TensorHandle* handle) { handle_.reset(handle); } + + struct TFETensorHandleDeleter { + void operator()(TFE_TensorHandle* p) const { TFE_DeleteTensorHandle(p); } + }; + std::unique_ptr handle_; +}; + +inline Tensor TensorHandle::Resolve(Status* status) { + TF_Tensor* tensor = + TFE_TensorHandleResolve(handle_.get(), status->GetTFStatus()); + if (!status->ok()) { + return Tensor(nullptr); + } + return Tensor(tensor); +} + +inline TensorHandle TensorHandle::FromTensor(const Tensor& tensor, + const Runtime& runtime, + Status* status) { + TFE_TensorHandle* tensor_handle = TFE_NewTensorHandleFromTensor( + runtime.GetTFEContext(), tensor.GetTFTensor(), status->GetTFStatus()); + if (!status->ok()) { + return TensorHandle(nullptr); + } + return TensorHandle(tensor_handle); +} + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_ diff --git a/tensorflow/cc/experimental/base/tests/BUILD b/tensorflow/cc/experimental/base/tests/BUILD new file mode 100644 index 00000000000..f449d618f72 --- /dev/null +++ b/tensorflow/cc/experimental/base/tests/BUILD @@ -0,0 +1,50 @@ +# Tests for the C++ header-only base types. +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +package( + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "tensor_types_test_util", + testonly = True, + hdrs = ["tensor_types_test_util.h"], + deps = [ + "//tensorflow/c:tf_datatype", + ], +) + +tf_cc_test( + name = "tensor_test", + srcs = [ + "tensor_test.cc", + ], + deps = [ + ":tensor_types_test_util", + "//tensorflow/c:tf_datatype", + "//tensorflow/cc/experimental/base/public:status", + "//tensorflow/cc/experimental/base/public:tensor", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "tensorhandle_test", + srcs = [ + "tensorhandle_test.cc", + ], + deps = [ + ":tensor_types_test_util", + "//tensorflow/c:tf_datatype", + "//tensorflow/cc/experimental/base/public:runtime", + "//tensorflow/cc/experimental/base/public:runtime_builder", + "//tensorflow/cc/experimental/base/public:status", + "//tensorflow/cc/experimental/base/public:tensor", + "//tensorflow/cc/experimental/base/public:tensorhandle", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/cc/experimental/base/tests/tensor_test.cc b/tensorflow/cc/experimental/base/tests/tensor_test.cc new file mode 100644 index 00000000000..33f9ab637e8 --- /dev/null +++ b/tensorflow/cc/experimental/base/tests/tensor_test.cc @@ -0,0 +1,163 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/cc/experimental/base/public/tensor.h" + +#include +#include + +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/test.h" + +namespace { + +using tensorflow::experimental::cc::Status; +using tensorflow::experimental::cc::Tensor; + +using SimpleTypes = ::testing::Types< + tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type, + tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type, + tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>; + +template +class ConstructScalarTensorTest : public ::testing::Test {}; +TYPED_TEST_SUITE(ConstructScalarTensorTest, SimpleTypes); + +// This test constructs a scalar tensor for each of the types in "SimpleTypes", +// and verifies the expected dimensions, dtype, value, number of bytes, and +// number of elements. +TYPED_TEST(ConstructScalarTensorTest, ValidTensorAttributesAfterConstruction) { + Status status; + TF_DataType dtype = TypeParam::kDType; + typename TypeParam::type value = 42; + Tensor tensor = Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{}, + /*data=*/&value, + /*len=*/sizeof(value), + /*deleter=*/[](void*, size_t) {}, &status); + ASSERT_TRUE(status.ok()) << status.message(); + + EXPECT_EQ(tensor.dims(), 0); + EXPECT_EQ(tensor.dtype(), dtype); + EXPECT_EQ(*reinterpret_cast(tensor.data()), 42); + EXPECT_EQ(tensor.num_bytes(), sizeof(typename TypeParam::type)); + EXPECT_EQ(tensor.num_elements(), 1); +} + +template +class Construct1DTensorTest : public ::testing::Test {}; +TYPED_TEST_SUITE(Construct1DTensorTest, SimpleTypes); + +// This test constructs a 1D tensor for each of the types in "SimpleTypes", +// and verifies the expected dimensions, dtype, value, number of bytes, and +// number of elements. +TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) { + Status status; + TF_DataType dtype = TypeParam::kDType; + // This is our 1D tensor of varying dtype. + std::vector value = {42, 100, 0, 1, 4, 29}; + // Shape is Rank 1 vector. + std::vector shape; + shape.push_back(value.size()); + + Tensor tensor = Tensor::FromBuffer( + /*dtype=*/dtype, /*shape=*/shape, + /*data=*/value.data(), + /*len=*/value.size() * sizeof(typename TypeParam::type), + /*deleter=*/[](void*, size_t) {}, &status); + ASSERT_TRUE(status.ok()) << status.message(); + + EXPECT_EQ(tensor.dims(), 1); + EXPECT_EQ(tensor.dtype(), dtype); + tensorflow::gtl::ArraySlice tensor_view( + reinterpret_cast(tensor.data()), value.size()); + EXPECT_EQ(tensor_view[0], 42); + EXPECT_EQ(tensor_view[1], 100); + EXPECT_EQ(tensor_view[2], 0); + EXPECT_EQ(tensor_view[3], 1); + EXPECT_EQ(tensor_view[4], 4); + EXPECT_EQ(tensor_view[5], 29); + + EXPECT_EQ(tensor.num_bytes(), + value.size() * sizeof(typename TypeParam::type)); + EXPECT_EQ(tensor.num_elements(), value.size()); +} + +template +class Construct2DTensorTest : public ::testing::Test {}; +TYPED_TEST_SUITE(Construct2DTensorTest, SimpleTypes); + +// This test constructs a 2D tensor for each of the types in "SimpleTypes", +// and verifies the expected dimensions, dtype, value, number of bytes, and +// number of elements. +TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) { + Status status; + TF_DataType dtype = TypeParam::kDType; + // This is our 1D tensor of varying dtype. 
+ std::vector value = {42, 100, 0, 1, 4, 29}; + // Shape is Rank 2 vector with shape 2 x 3. + std::vector shape({2, 3}); + + Tensor tensor = Tensor::FromBuffer( + /*dtype=*/dtype, /*shape=*/shape, + /*data=*/value.data(), + /*len=*/value.size() * sizeof(typename TypeParam::type), + /*deleter=*/[](void*, size_t) {}, &status); + + ASSERT_TRUE(status.ok()) << status.message(); + + EXPECT_EQ(tensor.dims(), 2); + EXPECT_EQ(tensor.dtype(), dtype); + tensorflow::gtl::ArraySlice tensor_view( + reinterpret_cast(tensor.data()), value.size()); + EXPECT_EQ(tensor_view[0], 42); + EXPECT_EQ(tensor_view[1], 100); + EXPECT_EQ(tensor_view[2], 0); + EXPECT_EQ(tensor_view[3], 1); + EXPECT_EQ(tensor_view[4], 4); + EXPECT_EQ(tensor_view[5], 29); + + EXPECT_EQ(tensor.num_bytes(), + value.size() * sizeof(typename TypeParam::type)); + EXPECT_EQ(tensor.num_elements(), value.size()); +} + +TEST(CPPTensorAPI, ConstructTensorFromBuffer) { + bool done = false; + Status status; + std::vector data_vector({12, 14, 20, 18, 39, 42, 100}); + { + // data_vector is a rank 1 tensor. + std::vector shape; + shape.push_back(data_vector.size()); + + Tensor::DeleterCallback callback = [&done](void* data, size_t len) { + done = true; + }; + + Tensor tensor = + Tensor::FromBuffer(/*dtype=*/TF_INT32, /*shape=*/shape, + /*data=*/data_vector.data(), + /*len=*/data_vector.size() * sizeof(int32_t), + /*deleter=*/callback, &status); + ASSERT_TRUE(status.ok()) << status.message(); + } + // At this point, tensor has been destroyed, and the deleter callback should + // have run. + EXPECT_TRUE(done); +} + +} // namespace diff --git a/tensorflow/cc/experimental/base/tests/tensor_types_test_util.h b/tensorflow/cc/experimental/base/tests/tensor_types_test_util.h new file mode 100644 index 00000000000..af9cad7529b --- /dev/null +++ b/tensorflow/cc/experimental/base/tests/tensor_types_test_util.h @@ -0,0 +1,76 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_ +#define TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_ + +#include + +#include "tensorflow/c/tf_datatype.h" + +namespace tensorflow { + +// Each of the following struct types have two members: a kDType that +// corresponds to a TF_Datatype enum value, and a typedef "type" +// of its corresponding C++ type. 
These types allow us to write Dtype-agnostic +// tests via GoogleTest's TypedTests: +// https://github.com/google/googletest/blob/e589a337170554c48bc658cc857cf15080c9eacc/googletest/docs/advanced.md#typed-tests +struct FloatType { + using type = float; + static constexpr TF_DataType kDType = TF_FLOAT; +}; + +struct DoubleType { + using type = double; + static constexpr TF_DataType kDType = TF_DOUBLE; +}; + +struct Int32Type { + using type = int32_t; + static constexpr TF_DataType kDType = TF_INT32; +}; + +struct UINT8Type { + using type = uint8_t; + static constexpr TF_DataType kDType = TF_UINT8; +}; + +struct INT8Type { + using type = int8_t; + static constexpr TF_DataType kDType = TF_INT8; +}; + +struct INT64Type { + using type = int64_t; + static constexpr TF_DataType kDType = TF_INT64; +}; + +struct UINT16Type { + using type = uint16_t; + static constexpr TF_DataType kDType = TF_UINT16; +}; + +struct UINT32Type { + using type = uint32_t; + static constexpr TF_DataType kDType = TF_UINT32; +}; + +struct UINT64Type { + using type = uint64_t; + static constexpr TF_DataType kDType = TF_UINT64; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_ diff --git a/tensorflow/cc/experimental/base/tests/tensorhandle_test.cc b/tensorflow/cc/experimental/base/tests/tensorhandle_test.cc new file mode 100644 index 00000000000..cfeaba4e392 --- /dev/null +++ b/tensorflow/cc/experimental/base/tests/tensorhandle_test.cc @@ -0,0 +1,184 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/experimental/base/public/tensorhandle.h" + +#include +#include + +#include + +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/cc/experimental/base/public/runtime.h" +#include "tensorflow/cc/experimental/base/public/runtime_builder.h" +#include "tensorflow/cc/experimental/base/public/tensor.h" +#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +using tensorflow::experimental::cc::Runtime; +using tensorflow::experimental::cc::RuntimeBuilder; +using tensorflow::experimental::cc::Status; +using tensorflow::experimental::cc::Tensor; +using tensorflow::experimental::cc::TensorHandle; + +using SimpleTypes = ::testing::Types< + tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type, + tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type, + tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>; + +template +class ConstructScalarTensorHandleTest : public ::testing::Test {}; +TYPED_TEST_SUITE(ConstructScalarTensorHandleTest, SimpleTypes); + +// This test constructs a scalar tensor for each of the types in "SimpleTypes", +// then wraps it in a TensorHandle. 
We then unwrap it back into a Tensor, and +// verify the expected dims, dtype, value, num bytes, and num elements. +TYPED_TEST(ConstructScalarTensorHandleTest, + ValidTensorAttributesAfterConstruction) { + Status status; + RuntimeBuilder runtime_builder; + std::unique_ptr runtime = runtime_builder.Build(&status); + ASSERT_TRUE(status.ok()) << status.message(); + + TF_DataType dtype = TypeParam::kDType; + typename TypeParam::type value = 42; + Tensor original_tensor = + Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{}, + /*data=*/&value, + /*len=*/sizeof(value), + /*deleter=*/[](void*, size_t) {}, &status); + ASSERT_TRUE(status.ok()) << status.message(); + + TensorHandle handle = + TensorHandle::FromTensor(original_tensor, *runtime, &status); + ASSERT_TRUE(status.ok()) << status.message(); + + Tensor tensor = handle.Resolve(&status); + ASSERT_TRUE(status.ok()) << status.message(); + + EXPECT_EQ(tensor.dims(), 0); + EXPECT_EQ(tensor.dtype(), dtype); + EXPECT_EQ(*reinterpret_cast(tensor.data()), 42); + EXPECT_EQ(tensor.num_bytes(), sizeof(typename TypeParam::type)); + EXPECT_EQ(tensor.num_elements(), 1); +} + +template +class Construct1DTensorHandleTest : public ::testing::Test {}; +TYPED_TEST_SUITE(Construct1DTensorHandleTest, SimpleTypes); + +// This test constructs a 1D tensor for each of the types in "SimpleTypes", +// and verifies the expected dimensions, dtype, value, number of bytes, and +// number of elements. +TYPED_TEST(Construct1DTensorHandleTest, + ValidTensorAttributesAfterConstruction) { + Status status; + RuntimeBuilder runtime_builder; + std::unique_ptr runtime = runtime_builder.Build(&status); + ASSERT_TRUE(status.ok()) << status.message(); + + TF_DataType dtype = TypeParam::kDType; + // This is our 1D tensor of varying dtype. + std::vector value = {42, 100, 0, 1, 4, 29}; + // Shape is Rank 1 vector. + std::vector shape; + shape.push_back(value.size()); + + Tensor original_tensor = Tensor::FromBuffer( + /*dtype=*/dtype, /*shape=*/shape, + /*data=*/value.data(), + /*len=*/value.size() * sizeof(typename TypeParam::type), + /*deleter=*/[](void*, size_t) {}, &status); + ASSERT_TRUE(status.ok()) << status.message(); + + TensorHandle handle = + TensorHandle::FromTensor(original_tensor, *runtime, &status); + ASSERT_TRUE(status.ok()) << status.message(); + + Tensor tensor = handle.Resolve(&status); + ASSERT_TRUE(status.ok()) << status.message(); + + EXPECT_EQ(tensor.dims(), 1); + EXPECT_EQ(tensor.dtype(), dtype); + tensorflow::gtl::ArraySlice tensor_view( + reinterpret_cast(tensor.data()), value.size()); + EXPECT_EQ(tensor_view[0], 42); + EXPECT_EQ(tensor_view[1], 100); + EXPECT_EQ(tensor_view[2], 0); + EXPECT_EQ(tensor_view[3], 1); + EXPECT_EQ(tensor_view[4], 4); + EXPECT_EQ(tensor_view[5], 29); + + EXPECT_EQ(tensor.num_bytes(), + value.size() * sizeof(typename TypeParam::type)); + EXPECT_EQ(tensor.num_elements(), value.size()); +} + +template +class Construct2DTensorHandleTest : public ::testing::Test {}; +TYPED_TEST_SUITE(Construct2DTensorHandleTest, SimpleTypes); + +// This test constructs a 2D tensor for each of the types in "SimpleTypes", +// and verifies the expected dimensions, dtype, value, number of bytes, and +// number of elements. +TYPED_TEST(Construct2DTensorHandleTest, + ValidTensorAttributesAfterConstruction) { + Status status; + RuntimeBuilder runtime_builder; + std::unique_ptr runtime = runtime_builder.Build(&status); + ASSERT_TRUE(status.ok()) << status.message(); + + TF_DataType dtype = TypeParam::kDType; + // This is our 1D tensor of varying dtype. 
+ std::vector value = {42, 100, 0, 1, 4, 29}; + // Shape is Rank 2 vector with shape 2 x 3. + std::vector shape({2, 3}); + + Tensor original_tensor = Tensor::FromBuffer( + /*dtype=*/dtype, /*shape=*/shape, + /*data=*/value.data(), + /*len=*/value.size() * sizeof(typename TypeParam::type), + /*deleter=*/[](void*, size_t) {}, &status); + ASSERT_TRUE(status.ok()) << status.message(); + + TensorHandle handle = + TensorHandle::FromTensor(original_tensor, *runtime, &status); + ASSERT_TRUE(status.ok()) << status.message(); + + Tensor tensor = handle.Resolve(&status); + ASSERT_TRUE(status.ok()) << status.message(); + + EXPECT_EQ(tensor.dims(), 2); + EXPECT_EQ(tensor.dtype(), dtype); + tensorflow::gtl::ArraySlice tensor_view( + reinterpret_cast(tensor.data()), value.size()); + EXPECT_EQ(tensor_view[0], 42); + EXPECT_EQ(tensor_view[1], 100); + EXPECT_EQ(tensor_view[2], 0); + EXPECT_EQ(tensor_view[3], 1); + EXPECT_EQ(tensor_view[4], 4); + EXPECT_EQ(tensor_view[5], 29); + + EXPECT_EQ(tensor.num_bytes(), + value.size() * sizeof(typename TypeParam::type)); + EXPECT_EQ(tensor.num_elements(), value.size()); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 882b4032f76..b13d8db48a9 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -4,7 +4,6 @@ load( "//tensorflow:tensorflow.bzl", "if_android", - "if_ios", "if_mobile", "if_not_mobile", "tf_cc_test", @@ -85,7 +84,7 @@ cc_library( "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", ]) + if_android([ - "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core:portable_tensorflow_lib", ]), ) diff --git a/tensorflow/cc/saved_model/experimental/public/BUILD b/tensorflow/cc/saved_model/experimental/public/BUILD new file mode 100644 index 00000000000..3e9a671a61f --- /dev/null +++ b/tensorflow/cc/saved_model/experimental/public/BUILD @@ -0,0 +1,58 @@ +# Experimental C++ SavedModel Header Only APIs. See RFC +# https://github.com/tensorflow/community/pull/207 + +package( + # This is intentionally public + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "concrete_function", + hdrs = [ + "concrete_function.h", + ], + deps = [ + ":function_metadata", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/experimental/saved_model/public:concrete_function", + "//tensorflow/cc/experimental/base/public:status", + ], +) + +cc_library( + name = "concrete_function_list", + hdrs = [ + "concrete_function_list.h", + ], + deps = [ + ":concrete_function", + "//tensorflow/c/experimental/saved_model/public:concrete_function_list", + ], +) + +cc_library( + name = "function_metadata", + hdrs = [ + "function_metadata.h", + ], + deps = [ + "//tensorflow/c/experimental/saved_model/public:function_metadata", + ], +) + +cc_library( + name = "saved_model_api", + hdrs = [ + "saved_model_api.h", + ], + deps = [ + ":concrete_function", + ":concrete_function_list", + "//tensorflow/c/experimental/saved_model/public:saved_model_api", + "//tensorflow/cc/experimental/base/public:runtime", + "//tensorflow/cc/experimental/base/public:status", + ], +) diff --git a/tensorflow/cc/saved_model/experimental/public/concrete_function.h b/tensorflow/cc/saved_model/experimental/public/concrete_function.h new file mode 100644 index 00000000000..1adaf70b01a --- /dev/null +++ b/tensorflow/cc/saved_model/experimental/public/concrete_function.h @@ -0,0 +1,61 @@ +/* Copyright 2020 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_ +#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_ + +#include + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/experimental/saved_model/public/concrete_function.h" +#include "tensorflow/cc/experimental/base/public/status.h" +#include "tensorflow/cc/saved_model/experimental/public/function_metadata.h" + +namespace tensorflow { +namespace experimental { +namespace cc { + +// ConcreteFunction is an executable "function" loaded from a SavedModelAPI. +class ConcreteFunction final { + public: + // TODO(bmzhao): Adding ConcreteFunction::Run in subsequent CL, since + // it depends on tensorflow::cc::Tensor and tensorflow::cc::TensorHandle + + // Returns FunctionMetadata associated with this ConcreteFunction. + const FunctionMetadata* GetFunctionMetadata(); + + private: + friend class SavedModelAPI; + friend class ConcreteFunctionList; + + // TODO(bmzhao): Consider adding a macro for wrapping/unwrapping + // when moving out of experimental. + static ConcreteFunction* wrap(TF_ConcreteFunction* p) { + return reinterpret_cast(p); + } + static TF_ConcreteFunction* unwrap(ConcreteFunction* p) { + return reinterpret_cast(p); + } +}; + +inline const FunctionMetadata* ConcreteFunction::GetFunctionMetadata() { + return FunctionMetadata::wrap(TF_ConcreteFunctionGetMetadata(unwrap(this))); +} + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_ diff --git a/tensorflow/cc/saved_model/experimental/public/concrete_function_list.h b/tensorflow/cc/saved_model/experimental/public/concrete_function_list.h new file mode 100644 index 00000000000..88cb779ef15 --- /dev/null +++ b/tensorflow/cc/saved_model/experimental/public/concrete_function_list.h @@ -0,0 +1,63 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
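The wrap/unwrap helpers above rely on a convention used throughout these headers: the C++ class carries no state of its own, so a pointer to the opaque C struct can simply be reinterpreted as a pointer to the C++ view type. A stripped-down sketch with hypothetical names (TF_Thing and TF_ThingId are placeholders, not real C API symbols):

    // Opaque C type, only ever handled through pointers.
    typedef struct TF_Thing TF_Thing;
    extern "C" int TF_ThingId(TF_Thing* t);  // Hypothetical C accessor.

    class Thing {
     public:
      int id() { return TF_ThingId(unwrap(this)); }

     private:
      // No data members: a Thing* is just a typed view over a TF_Thing*, so
      // the reinterpret_casts below never touch memory.
      static Thing* wrap(TF_Thing* p) { return reinterpret_cast<Thing*>(p); }
      static TF_Thing* unwrap(Thing* p) { return reinterpret_cast<TF_Thing*>(p); }
    };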
+==============================================================================*/ + +#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_ +#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_ + +#include + +#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h" +#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h" + +namespace tensorflow { +namespace experimental { +namespace cc { + +// ConcreteFunctionList helps convert an opaque pointer to an array of +// ConcreteFunction pointers to a std::vector. +class ConcreteFunctionList { + public: + // Converts this object to a std::vector + std::vector ToVector(); + + private: + friend class SavedModelAPI; + // Wraps a TF_ConcreteFunctionList. Takes ownership of list. + explicit ConcreteFunctionList(TF_ConcreteFunctionList* list) : list_(list) {} + + struct TFConcreteFunctionListDeleter { + void operator()(TF_ConcreteFunctionList* p) const { + TF_DeleteConcreteFunctionList(p); + } + }; + std::unique_ptr list_; +}; + +inline std::vector ConcreteFunctionList::ToVector() { + int size = TF_ConcreteFunctionListSize(list_.get()); + std::vector result; + result.reserve(size); + for (int i = 0; i < size; ++i) { + result.push_back( + ConcreteFunction::wrap(TF_ConcreteFunctionListGet(list_.get(), i))); + } + return result; +} + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_ diff --git a/tensorflow/cc/saved_model/experimental/public/function_metadata.h b/tensorflow/cc/saved_model/experimental/public/function_metadata.h new file mode 100644 index 00000000000..11e1a860d84 --- /dev/null +++ b/tensorflow/cc/saved_model/experimental/public/function_metadata.h @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_ +#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_ + +#include + +#include "tensorflow/c/experimental/saved_model/public/function_metadata.h" + +namespace tensorflow { +namespace experimental { +namespace cc { + +// FunctionMetadata stores additional function information, including +// optional signaturedef feeds/fetches (for TF1-based ConcreteFunctions), +// a valid function path (for TF2-based ConcreteFunctions), and +// the types + number of inputs and outputs. +class FunctionMetadata final { + // TODO(bmzhao): Add getters here as necessary. 
+ private: + friend class ConcreteFunction; + static FunctionMetadata* wrap(TF_FunctionMetadata* p) { + return reinterpret_cast(p); + } + static TF_FunctionMetadata* unwrap(FunctionMetadata* p) { + return reinterpret_cast(p); + } +}; + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_ diff --git a/tensorflow/cc/saved_model/experimental/public/saved_model_api.h b/tensorflow/cc/saved_model/experimental/public/saved_model_api.h new file mode 100644 index 00000000000..04018bf2aab --- /dev/null +++ b/tensorflow/cc/saved_model/experimental/public/saved_model_api.h @@ -0,0 +1,162 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_ +#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_ + +#include +#include +#include +#include + +#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h" +#include "tensorflow/cc/experimental/base/public/runtime.h" +#include "tensorflow/cc/experimental/base/public/status.h" +#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h" +#include "tensorflow/cc/saved_model/experimental/public/concrete_function_list.h" + +namespace tensorflow { +namespace experimental { +namespace cc { + +// SavedModelAPI offers a way to load Tensorflow Saved Models +// (https://www.tensorflow.org/guide/saved_model) and execute saved +// tf.functions or legacy SignatureDefs in a TF2-idiomatic fashion. +// See RFC 207 +// (https://github.com/tensorflow/community/blob/master/rfcs/20200218-tf-c-saved-model.md) +// TODO(bmzhao): Add an e2e example here, once ConcreteFunction::Run is added. +class SavedModelAPI { + public: + // Load a SavedModel from `dirname`. + // + // Params: + // saved_model_path - A directory filepath that the SavedModel is at. + // runtime - A runtime used to load SavedModelAPI. `runtime` must outlive the + // returned TF_SavedModel pointer. + // tags - Optional set of tags. If tags = nullptr, we expect the SavedModel + // to contain a single Metagraph (as for those exported from TF2's + // `tf.saved_model.save`). If tags != nullptr, we load the metagraph + // matching the tags: + // https://github.com/tensorflow/tensorflow/blob/428cdeda09aef81e958eeb274b83d27ad635b57b/tensorflow/core/protobuf/meta_graph.proto#L50-L56 + // status - Set to OK on success and an appropriate error on failure. + // Returns: + // If status is not OK, returns nullptr. + static std::unique_ptr Load( + const std::string& saved_model_path, const Runtime& runtime, + Status* status, const std::unordered_set* tags = nullptr); + + // Retrieve a function from the TF2 SavedModel via function path. + // + // Params: + // function_path - A string containing the path from the root saved python + // object to a tf.function method. 
+ // status - Set to OK on success and an appropriate error on failure. + // Returns: + // If status is not OK, returns nullptr. Otherwise, returns a + // tensorflow::cc::ConcreteFunction pointer. The lifetime of this pointer + // is bound to SavedModelAPI it was loaded from. + ConcreteFunction* GetConcreteFunction(const std::string& function_path, + Status* status); + + // Retrieve a function from the TF SavedModel via a SignatureDef key. + // + // Params: + // signature_def_key - String key of SignatureDef map of a SavedModel: + // https://github.com/tensorflow/tensorflow/blob/69b08900b1e991d84bce31f3b404f5ed768f339f/tensorflow/core/protobuf/meta_graph.proto#L89 + // status - Set to OK on success and an appropriate error on failure. + // Returns: + // If status is not OK, returns nullptr. Otherwise, returns a + // tensorflow::cc::ConcreteFunction pointer. The lifetime of this pointer + // is bound to SavedModelAPI it was loaded from. + ConcreteFunction* GetSignatureDefFunction(const std::string& function_path, + Status* status); + + // Lists all Conrete Functions available from the SavedModel. + std::vector ListFunctions(); + + // SavedModelAPI is movable, but not copyable. + SavedModelAPI(SavedModelAPI&&) = default; + SavedModelAPI& operator=(SavedModelAPI&&) = default; + + private: + SavedModelAPI(const SavedModelAPI&) = delete; + SavedModelAPI& operator=(const SavedModelAPI&) = delete; + + explicit SavedModelAPI(TF_SavedModel* model) : saved_model_(model) {} + struct TFSavedModelDeleter { + void operator()(TF_SavedModel* p) const { TF_DeleteSavedModel(p); } + }; + std::unique_ptr saved_model_; +}; + +inline std::unique_ptr SavedModelAPI::Load( + const std::string& saved_model_path, const Runtime& runtime, Status* status, + const std::unordered_set* tags) { + TF_SavedModel* saved_model = nullptr; + + if (tags == nullptr) { + saved_model = + TF_LoadSavedModel(saved_model_path.c_str(), runtime.GetTFEContext(), + status->GetTFStatus()); + } else { + std::vector tags_vector; + tags_vector.reserve(tags->size()); + for (const std::string& tag : *tags) { + tags_vector.push_back(tag.c_str()); + } + saved_model = TF_LoadSavedModelWithTags( + saved_model_path.c_str(), runtime.GetTFEContext(), tags_vector.data(), + tags_vector.size(), status->GetTFStatus()); + } + + if (!status->ok()) { + return nullptr; + } + + // We can't use std::make_unique here because of its interaction with a + // private constructor: https://abseil.io/tips/134 + return std::unique_ptr(new SavedModelAPI(saved_model)); +} + +inline ConcreteFunction* SavedModelAPI::GetConcreteFunction( + const std::string& function_path, Status* status) { + TF_ConcreteFunction* function = TF_GetSavedModelConcreteFunction( + saved_model_.get(), function_path.c_str(), status->GetTFStatus()); + if (!status->ok()) { + return nullptr; + } + return ConcreteFunction::wrap(function); +} + +inline ConcreteFunction* SavedModelAPI::GetSignatureDefFunction( + const std::string& function_path, Status* status) { + TF_ConcreteFunction* function = TF_GetSavedModelSignatureDefFunction( + saved_model_.get(), function_path.c_str(), status->GetTFStatus()); + if (!status->ok()) { + return nullptr; + } + return ConcreteFunction::wrap(function); +} + +inline std::vector SavedModelAPI::ListFunctions() { + ConcreteFunctionList list(TF_ListSavedModelFunctions(saved_model_.get())); + return list.ToVector(); +} + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_ diff 
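The TODO above asks for an end-to-end example. A minimal sketch of the intended flow, stopping short of execution since ConcreteFunction::Run is not part of this change; the export directory and the "serving_fn" function path are placeholders:

    #include <memory>
    #include <string>

    #include "tensorflow/cc/experimental/base/public/runtime.h"
    #include "tensorflow/cc/experimental/base/public/runtime_builder.h"
    #include "tensorflow/cc/experimental/base/public/status.h"
    #include "tensorflow/cc/saved_model/experimental/public/concrete_function.h"
    #include "tensorflow/cc/saved_model/experimental/public/saved_model_api.h"

    namespace cc = tensorflow::experimental::cc;

    bool LoadAndLookUp(const std::string& export_dir) {
      cc::Status status;
      cc::RuntimeBuilder builder;
      std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
      if (!status.ok()) return false;

      // TF2 SavedModels exported by tf.saved_model.save have a single
      // metagraph, so no tags are needed; pass a tag set for TF1-style models.
      std::unique_ptr<cc::SavedModelAPI> model =
          cc::SavedModelAPI::Load(export_dir, *runtime, &status);
      if (!status.ok()) return false;

      // "serving_fn" is a placeholder function path.
      cc::ConcreteFunction* fn =
          model->GetConcreteFunction("serving_fn", &status);
      return status.ok() && fn != nullptr;
    }

As the tests below note, loading currently returns TF_UNIMPLEMENTED, so this illustrates the intended call pattern rather than a path that succeeds today.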
--git a/tensorflow/cc/saved_model/experimental/tests/BUILD b/tensorflow/cc/saved_model/experimental/tests/BUILD new file mode 100644 index 00000000000..f24bcfdee2a --- /dev/null +++ b/tensorflow/cc/saved_model/experimental/tests/BUILD @@ -0,0 +1,22 @@ +# Tests for the C++ header-only SavedModelAPI. +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +package( + licenses = ["notice"], # Apache 2.0 +) + +tf_cc_test( + name = "saved_model_api_test", + srcs = [ + "saved_model_api_test.cc", + ], + deps = [ + "//tensorflow/cc/experimental/base/public:runtime", + "//tensorflow/cc/experimental/base/public:runtime_builder", + "//tensorflow/cc/experimental/base/public:status", + "//tensorflow/cc/saved_model/experimental/public:saved_model_api", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc b/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc new file mode 100644 index 00000000000..7f7f6b09a6d --- /dev/null +++ b/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc @@ -0,0 +1,100 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/saved_model/experimental/public/saved_model_api.h" + +#include +#include +#include + +#include "tensorflow/cc/experimental/base/public/runtime.h" +#include "tensorflow/cc/experimental/base/public/runtime_builder.h" +#include "tensorflow/cc/experimental/base/public/status.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/test.h" + + +namespace { + +using tensorflow::experimental::cc::Runtime; +using tensorflow::experimental::cc::RuntimeBuilder; +using tensorflow::experimental::cc::SavedModelAPI; +using tensorflow::experimental::cc::Status; + +constexpr char kTestData[] = "cc/saved_model/testdata"; + +std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) { + return tensorflow::io::JoinPath(tensorflow::testing::TensorFlowSrcRoot(), + kTestData, saved_model_dir); +} + +// This value parameterized test allows us to test both TFRT +// and non TFRT runtimes. +// https://github.com/google/googletest/blob/dcc92d0ab6c4ce022162a23566d44f673251eee4/googletest/docs/advanced.md#value-parameterized-tests +class CPPSavedModelAPITest : public ::testing::TestWithParam {}; + +TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) { + Status status; + RuntimeBuilder builder; + bool use_tfrt = GetParam(); + if (use_tfrt) { + GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced. 
+ } + + builder.SetUseTFRT(use_tfrt); + std::unique_ptr runtime = builder.Build(&status); + ASSERT_TRUE(status.ok()) << status.message(); + + std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph"); + std::unordered_set tags = {"serve"}; + std::unique_ptr model = + SavedModelAPI::Load(model_dir, *runtime, &status, &tags); + + // TODO(bmzhao): Change this to expect TF_OK when loading is implemented. + // That unblocks writing other tests that require a TF_SavedModel*, + // like loading a ConcreteFunction. This test at least checks that the + // C API builds and can be minimally run. + EXPECT_EQ(status.code(), TF_UNIMPLEMENTED); +} + +TEST_P(CPPSavedModelAPITest, LoadsSavedModel) { + Status status; + RuntimeBuilder builder; + bool use_tfrt = GetParam(); + if (use_tfrt) { + GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced. + } + + builder.SetUseTFRT(use_tfrt); + std::unique_ptr runtime = builder.Build(&status); + ASSERT_TRUE(status.ok()) << status.message(); + + std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph"); + std::unique_ptr model = + SavedModelAPI::Load(model_dir, *runtime, &status); + + // TODO(bmzhao): Change this to expect TF_OK when loading is implemented. + // That unblocks writing other tests that require a TF_SavedModel*, + // like loading a ConcreteFunction. This test at least checks that the + // C API builds and can be minimally run. + EXPECT_EQ(status.code(), TF_UNIMPLEMENTED); +} + +INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticCPPSavedModelTests, + CPPSavedModelAPITest, ::testing::Bool()); + +} // namespace + diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index c9a36b88795..e4df3090046 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -131,6 +131,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape, TF_RETURN_IF_ERROR(XLATypeToCpp(shape.element_type(), &type)); std::vector dim_vars; string dim_sizes, indices; + int count = 1; if (shape.rank() == 0 || (shape.dimensions_size() == 1 && shape.dimensions(0) == 1)) { dim_sizes = "[1]"; @@ -140,6 +141,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape, dim_vars.push_back(absl::StrCat("size_t dim", dim)); dim_sizes += absl::StrCat("[", shape.dimensions(dim), "]"); indices += absl::StrCat("[dim", dim, "]"); + count *= shape.dimensions(dim); } } rewrites->push_back({"{{I}}", absl::StrCat(i)}); @@ -147,6 +149,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape, rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")}); rewrites->push_back({"{{DIM_SIZES}}", dim_sizes}); rewrites->push_back({"{{INDICES}}", indices}); + rewrites->push_back({"{{COUNT}}", absl::StrCat(count)}); return Status::OK(); } @@ -199,6 +202,12 @@ Status GenArgMethods(const tf2xla::Config& config, return (*static_cast( arg_data({{I}}))){{INDICES}}; } + int arg{{NAME}}_size() const { + return {{COUNT}} * sizeof({{TYPE}}); + } + int arg{{NAME}}_count() const { + return {{COUNT}}; + } )"; *methods += RewriteWithName(absl::StrCat(i), code, rewrites); if (!config.feed(i).name().empty()) { @@ -246,6 +255,12 @@ Status GenResultMethods(const tf2xla::Config& config, return (*static_cast( result_data({{I}}))){{INDICES}}; } + int result{{NAME}}_size() const { + return {{COUNT}} * sizeof({{TYPE}}); + } + int result{{NAME}}_count() const { + return {{COUNT}}; + } )"; *methods += RewriteWithName(absl::StrCat(i), code, rewrites); if (!config.fetch(i).name().empty()) { @@ -281,6 +296,12 @@ Status 
GenVariableMethods(const tf2xla::Config& config, return (*static_cast( arg_data({{I}}))){{INDICES}}; } + int var_{{NAME}}_size() const { + return {{COUNT}} * sizeof({{TYPE}}); + } + int var_{{NAME}}_count() const { + return {{COUNT}}; + } )"; const tf2xla::Variable& var = config.variable(i - config.feed_size()); rewrites.emplace_back("{{MAYBE_CONST}}", var.readonly() ? "const " : ""); diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index af58ca233f0..d011279dbb7 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -138,6 +138,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { return (*static_cast( arg_data(0)))[dim0][dim1]; } + int arg0_size() const { + return 2 * sizeof(float); + } + int arg0_count() const { + return 2; + } void set_arg_myfeed_data(const void* data) { set_arg_data(0, data); @@ -156,6 +162,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { return (*static_cast( arg_data(0)))[dim0][dim1]; } + int arg_myfeed_size() const { + return 2 * sizeof(float); + } + int arg_myfeed_count() const { + return 2; + } void set_arg1_data(const void* data) { set_arg_data(1, data); @@ -174,6 +186,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { return (*static_cast( arg_data(1)))[dim0][dim1]; } + int arg1_size() const { + return 12 * sizeof(tensorflow::int64); + } + int arg1_count() const { + return 12; + } // Result methods for managing output buffers. Buffers are in row-major order. // Must only be called after a successful Run call. There is a set of methods @@ -204,6 +222,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { return (*static_cast( result_data(0)))[dim0][dim1]; } + int result0_size() const { + return 30 * sizeof(tensorflow::uint32); + } + int result0_count() const { + return 30; + } tensorflow::uint32* result_myfetch_data() { return static_cast(result_data(0)); @@ -219,6 +243,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { return (*static_cast( result_data(0)))[dim0][dim1]; } + int result_myfetch_size() const { + return 30 * sizeof(tensorflow::uint32); + } + int result_myfetch_count() const { + return 30; + } // Methods for managing variable buffers. Buffers are in row-major order. // @@ -261,6 +291,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { return (*static_cast( arg_data(2)))[0]; } + int var_myvar_readonly_size() const { + return 1 * sizeof(float); + } + int var_myvar_readonly_count() const { + return 1; + } void set_var_myvar_data(float* data) { set_arg_data(3, data); @@ -279,6 +315,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { return (*static_cast( arg_data(3)))[0]; } + int var_myvar_size() const { + return 1 * sizeof(float); + } + int var_myvar_count() const { + return 1; + } void set_var_myvar2_data(tensorflow::int32* data) { set_arg_data(4, data); @@ -297,6 +339,12 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction { return (*static_cast( arg_data(4)))[dim0]; } + int var_myvar2_size() const { + return 5 * sizeof(tensorflow::int32); + } + int var_myvar2_count() const { + return 5; + } private: // Number of buffers for the compiled computation. 
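The golden file shows the shape of the new accessors. A caller sizing buffers against the generated class might use them as below; this is a sketch assuming the MyClass generated for this test, and assuming its existing interface also exposes the arg_myfeed_data() pointer accessor alongside the setters shown above:

    #include <cstring>
    #include <vector>

    // Assumes the tfcompile-generated header declaring MyClass is included.
    // Copies a host vector into the generated computation's feed buffer,
    // using the new count/size accessors to validate and size the copy.
    void FillMyFeed(MyClass& computation, const std::vector<float>& values) {
      if (static_cast<int>(values.size()) != computation.arg_myfeed_count()) {
        return;  // Shape mismatch; arg_myfeed_count() is the element count.
      }
      // arg_myfeed_size() is the same quantity expressed in bytes.
      std::memcpy(computation.arg_myfeed_data(), values.data(),
                  computation.arg_myfeed_size());
    }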
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 28d922f9e3c..bc8fac0e88f 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -251,7 +251,7 @@ cc_library( visibility = [":friends"], deps = select({ "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core:portable_tensorflow_lib", ], "//conditions:default": [ "//tensorflow/core:graph", diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index ff76786a66f..174250f18bd 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -2078,6 +2078,8 @@ absl::flat_hash_set GetKnownXLAWhitelistOp() { "XlaSend", "XlaSharding", "XlaSort", + "XlaSpmdFullToShardShape", + "XlaSpmdShardToFullShape", "XlaSvd", "XlaWhile", "_Arg", diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index c90e8dead76..62b0c0ab4cf 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -41,6 +41,7 @@ limitations under the License. #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/core/public/version.h" @@ -277,29 +278,25 @@ Status XlaCompilationCache::CompileSingleOp( const NodeDef& node_def = ctx->op_kernel().def(); TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes)); - bool are_params = absl::c_all_of(args, [](const XlaCompiler::Argument arg) { - return arg.kind == XlaCompiler::Argument::kParameter; - }); + bool are_args_supported = + absl::c_all_of(args, [](const XlaCompiler::Argument arg) { + return arg.kind == XlaCompiler::Argument::kConstant || + arg.kind == XlaCompiler::Argument::kParameter; + }); const ConfigProto* config = ctx->function_library()->config_proto(); bool use_mlir = config && config->experimental().enable_mlir_bridge(); - // Use MLIR bridge if all the arguments are parameters. - // TODO(hinsu): Support other argument types instead of silently falling - // back to the XLA compiler. - if (!are_params || !use_mlir) { + // TODO(b/155596779): Understand the source of other argument types and + // depending on the source either support those or avoid these codepath. 
+ if (!use_mlir || !are_args_supported) { return compiler->CompileGraph(compile_options, node_def.name(), std::move(graph), args, result); } - absl::InlinedVector arg_shapes; - arg_shapes.reserve(args.size()); - for (const XlaCompiler::Argument& arg : args) { - arg_shapes.push_back(absl::get(arg.shape)); - } GraphDebugInfo debug_info; return CompileGraphToXlaHlo( - *graph, {arg_shapes.data(), arg_shapes.size()}, - options.device_type.type_string(), compile_options.use_tuple_arg, - *options.flib_def, debug_info, options.shape_representation_fn, result); + *graph, {args.data(), args.size()}, options.device_type.type_string(), + compile_options.use_tuple_arg, *options.flib_def, debug_info, + options.shape_representation_fn, result); }; return CompileImpl(options, name, args, compile_op, /*compile_threshold=*/absl::nullopt, diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 34ff0c55615..17e4226405a 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -180,12 +180,10 @@ class XlaAssignVariableOp : public OpKernel { data::MakeIteratorOp); \ REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE), \ data::AnonymousIteratorHandleOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("AnonymousIteratorV2").Device(DEVICE).HostMemory("deleter"), \ - data::AnonymousIteratorHandleOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("DeleteIterator").Device(DEVICE).HostMemory("deleter"), \ - data::DeleteIteratorOp); \ + REGISTER_KERNEL_BUILDER(Name("AnonymousIteratorV2").Device(DEVICE), \ + data::AnonymousIteratorHandleOp); \ + REGISTER_KERNEL_BUILDER(Name("DeleteIterator").Device(DEVICE), \ + data::DeleteIteratorOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \ data::IteratorGetNextOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE), \ diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index cb23137e7fe..9b5b0c209e5 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -31,7 +31,7 @@ filegroup( "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files", "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td", - "@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td", + "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", ], ) @@ -523,7 +523,6 @@ cc_library( "@flatbuffers", "@llvm-project//llvm:analysis", "@llvm-project//llvm:support", - "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:IR", "@llvm-project//mlir:TransformUtils", ], @@ -696,9 +695,9 @@ cc_library( "@com_google_absl//absl/strings", "@llvm-project//llvm:support", "@llvm-project//mlir:IR", - "@llvm-project//mlir:LoopOpsTransforms", "@llvm-project//mlir:MlirTranslateMain", "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:SCFTransforms", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:Translation", @@ -710,6 +709,8 @@ tf_cc_binary( name = "flatbuffer_translate", deps = [ ":flatbuffer_translate_registeration", + # TODO(b/155809683): Link only necessary dialects. + "@llvm-project//mlir:AllPassesAndDialects", ], ) @@ -758,6 +759,13 @@ tf_cc_binary( ":tf_tfl_passes", ":tf_tfl_translate_cl_options", ":tf_to_tfl_flatbuffer", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:support", + # TODO(b/155809683): Link only necessary dialects. 
+ "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", "//tensorflow/compiler/mlir:init_mlir", "//tensorflow/compiler/mlir/tensorflow:translate_cl_options", "//tensorflow/core:protos_all_cc", @@ -765,11 +773,6 @@ tf_cc_binary( "//tensorflow/lite:framework", "//tensorflow/lite/schema:schema_fbs", "//tensorflow/stream_executor/lib", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", ], ) @@ -781,17 +784,19 @@ tf_cc_binary( deps = [ ":flatbuffer_translate_lib", ":flatbuffer_translate_registeration", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:support", + # TODO(b/155809683): Link only necessary dialects. + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Support", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/core:lib", "//tensorflow/core/platform:logging", "//tensorflow/lite:framework", "//tensorflow/lite/delegates/flex:delegate", "//tensorflow/lite/kernels:builtin_ops", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Parser", - "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index e9192388070..6a631b1433d 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -1020,7 +1020,7 @@ Optional> Translator::BuildOperator( if (!inst->getMutableAttrDict().getAttrs().empty()) { os << " {"; bool first = true; - for (auto& named_attr : inst->getMutableAttrDict().getDictionary()) { + for (auto& named_attr : inst->getAttrDictionary()) { os << (!first ? ", " : ""); first = false; named_attr.first.print(os); diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 93bf1dcde53..a585b8e1520 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -20,7 +20,7 @@ limitations under the License. 
include "mlir/IR/OpBase.td" include "mlir/Interfaces/LoopLikeInterface.td" -include "mlir/Interfaces/SideEffects.td" +include "mlir/Interfaces/SideEffectInterfaces.td" include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td" include "tensorflow/compiler/mlir/lite/quantization/quantization.td" @@ -247,7 +247,14 @@ class TFL_TFTypesWithSameBits : Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa()">, CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>; -class TFL_OperandIsNoneOrHasRankLessThanOrEqualTo : +class TFL_TFOperandTypesWithSameBits : + And<[ + Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # i # ")).isa()">, + CPred<"getElementTypeOrSelf($_op.getOperand(" # i # ")).isUnsignedInteger(" # num # ")">]>, + Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa()">, + CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>; + +class TFL_OperandIsNoneOrHasRankAtMost : PredOpTrait<"operand " # n # " is at most " # m # "-D", Or<[ CPred<"$_op.getOperand(" # n # ").getType().isa()">, @@ -255,13 +262,13 @@ class TFL_OperandIsNoneOrHasRankLessThanOrEqualTo : CPred<"$_op.getOperand(" # n # ").getType().cast().getRank() <= " # m>]>>; -class TFL_OperandHasRankLessThanOrEqualTo : +class TFL_OperandHasRankAtMost : PredOpTrait<"operand " # n # " is at most " # m # "-D", Or<[TFL_OperandIsUnrankedPred, CPred<"$_op.getOperand(" # n # ").getType().cast().getRank() <= " # m>]>>; -class TFL_OperandHasRankGreaterThanOrEqualTo : +class TFL_OperandHasRankAtLeast : PredOpTrait<"operand " # n # " is at least " # m # "-D", Or<[TFL_OperandIsUnrankedPred, CPred<"$_op.getOperand(" # n # @@ -300,6 +307,18 @@ class TFL_TCresVTEtIsSameAsOp : And<[ "quant::QuantizedType::castToStorageType(" "getElementTypeOrSelf($_op.getOperand(" # j # ")))">]>]>]>; +// This is a quantization-aware version of TCresVTEtIsSameAsOp +class TFL_TCopVTEtAreSameAt : Or<[ + TCopVTEtAreSameAt<[i, j]>, + TFL_TFOperandTypesWithSameBits, + And<[ + SubstLeaves<"$_self", "getElementTypeOrSelf($_op.getOperand(" # j # "))", + quant_QuantizedType.predicate>, + CPred<"quant::QuantizedType::castToStorageType(" + "getElementTypeOrSelf($_op.getOperand(" # i # "))) == " + "quant::QuantizedType::castToStorageType(" + "getElementTypeOrSelf($_op.getOperand(" # j # ")))">]>]>; + //===----------------------------------------------------------------------===// // TFL op common constraints. //===----------------------------------------------------------------------===// @@ -395,9 +414,9 @@ class TFL_ConvOp : }]; let arguments = ( - ins TFL_TensorOf<[F32, QI8, QUI8]>:$input, + ins TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$input, TFL_TensorOf<[F32, QI8, QUI8]>:$filter, - TFL_TensorOfOrNone<[F32, I32]>:$bias, + TFL_TensorOfOrNone<[F32, I32, I64]>:$bias, I32Attr:$dilation_h_factor, I32Attr:$dilation_w_factor, TFL_AFAttr:$fused_activation_function, @@ -406,7 +425,7 @@ class TFL_ConvOp : I32Attr:$stride_w ); - let results = (outs TFL_TensorOf<[F32, QI8, QUI8]>:$output); + let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$output); let hasOptions = 0b1; } @@ -846,6 +865,40 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [ }]; } +def TFL_BatchMatMulOp : TFL_Op<"batch_matmul", [ + NoSideEffect, + TFL_OperandHasAtleastRank<0, 2>, + TFL_OperandHasAtleastRank<1, 2>, + SameOperandsAndResultElementType]> { + + let summary = "Batch Matrix Multiply Operator"; + + let description = [{ +Performs a batched matrix multiplication on the inputs. 
Follows the +conventions of TensorFlow BatchMatMulV2, with support for unknown dimensions +in the batch dimensions and broadcasting. + + Inputs: + `inputs[0]`: required: input LHS + `inputs[1]`: required: input RHS + `adjoint_lhs`: optional: Transpose LHS (default false) + `adjoint_lhs`: optional: Transpose LHS (default false) + }]; + + let arguments = (ins + TFL_TensorOf<[F32]>:$x, + TFL_TensorOf<[F32]>:$y, + DefaultValuedAttr:$adj_x, + DefaultValuedAttr:$adj_y + ); + + let results = (outs + TFL_TensorOf<[F32]>:$output + ); + + let hasOptions = 1; +} + def TFL_GatherOp : TFL_Op<"gather", [ NoSideEffect, SameOperandsAndResultsScale, @@ -929,7 +982,11 @@ def TFL_ScatterNdOp : TFL_Op<"scatter_nd", [ // Same type check of lhs and rhs is handled by the ResultsBroadcastableShape trait. def TFL_LessEqualOp : TFL_Op<"less_equal", [ - ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> { + ResultsBroadcastableShape, + BinaryOpSameElementTypeConstraint, + TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>, + NoSideEffect, + NoQuantizableResult]> { let summary = "Less_equal operator"; let description = [{ @@ -937,8 +994,8 @@ def TFL_LessEqualOp : TFL_Op<"less_equal", [ }]; let arguments = ( - ins TFL_TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$lhs, - TFL_TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$rhs); + ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8]>:$lhs, + TFL_TensorOf<[F32, I32, I64, QI8, QUI8]>:$rhs); let results = (outs TFL_BoolTensor:$output); @@ -951,9 +1008,12 @@ def TFL_LessEqualOp : TFL_Op<"less_equal", [ let hasOptions = 0; } -def TFL_LocalResponseNormalizationOp : TFL_Op<"local_response_normalization", - [NoSideEffect]> { - let summary = "Local Response Normalization."; +def TFL_LocalResponseNormalizationOp : TFL_Op<"local_response_normalization", [ + TFL_OperandHasRank<0, 4>, + SameOperandsAndResultShape, + SameOperandsAndResultType, + NoSideEffect]> { + let summary = "Local Response Normalization."; let description = [{ The 4-D `input` tensor is treated as a 3-D array of 1-D vectors (along the last @@ -970,7 +1030,7 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag }]; let arguments = (ins - TFL_TensorOf<[F32, QI8, QUI8]>:$input, + TFL_FpTensor:$input, I32Attr:$radius, F32Attr:$bias, F32Attr:$alpha, @@ -978,7 +1038,7 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag ); let results = (outs - TFL_TensorOf<[F32, QI8, QUI8]>:$output + TFL_FpTensor:$output ); let hasOptions = 1; @@ -1014,7 +1074,7 @@ def TFL_MatrixDiagOp : TFL_Op<"matrix_diag", [ NoSideEffect, TFL_OperandHasAtleastRank<0, 1>, PredOpTrait<"operand and result must have the same element type", - TCresVTEtIsSameAsOp<0, 0>>]> { + TFL_TCresVTEtIsSameAsOp<0, 0>>]> { let summary = [{ Returns a tensor with the provided diagonal and everything else padded with zeros. 
}]; @@ -1027,17 +1087,21 @@ def TFL_MatrixDiagOp : TFL_Op<"matrix_diag", [ }]; let arguments = (ins - TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$diagonal + TFL_TensorOf<[F32, I8, I16, I32, I64, TFL_Uint8, QUI8, QI8, TFL_Quint8]>:$diagonal ); let results = (outs - TFL_TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$output + TFL_TensorOf<[F32, I8, I16, I32, I64, TFL_Uint8, QUI8, QI8, TFL_Quint8]>:$output ); let hasOptions = 0; } -def TFL_MatrixSetDiagOp : TFL_Op<"matrix_set_diag", [NoSideEffect]> { +def TFL_MatrixSetDiagOp : TFL_Op<"matrix_set_diag", [ + TFL_OperandHasAtleastRank<0, 2>, + PredOpTrait<"input and result must have the same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoSideEffect]> { let summary = [{ Returns a batched matrix tensor with new batched diagonal values. }]; @@ -1049,12 +1113,12 @@ innermost matrices. These will be overwritten by the values in `diagonal`. }]; let arguments = (ins - TensorOf<[F32, I32, I64, I8, QI8, QI16, QUI8, TFL_Uint8, TFL_Quint8]>:$input, - TensorOf<[F32, I32, I64, I8, QI8, QI16, QUI8, TFL_Uint8, TFL_Quint8]>:$diagonal + TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$input, + TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$diagonal ); let results = (outs - TensorOf<[F32, I32, I64, I8, QI8, QI16, QUI8, TFL_Uint8, TFL_Quint8]>:$output + TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QI16, QUI8, TFL_Quint8]>:$result ); let hasOptions = 0; @@ -1172,7 +1236,12 @@ larger than 0. } def TFL_NotEqualOp : TFL_Op<"not_equal", [ - ResultsBroadcastableShape, Commutative, NoSideEffect, NoQuantizableResult]> { + TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>, + BinaryOpSameElementTypeConstraint, + ResultsBroadcastableShape, + Commutative, + NoSideEffect, + NoQuantizableResult]> { let summary = "Not_equal operator"; let description = [{ @@ -1180,8 +1249,8 @@ def TFL_NotEqualOp : TFL_Op<"not_equal", [ }]; let arguments = ( - ins AnyTensor:$lhs, - AnyTensor:$rhs); + ins TFL_TensorOf<[I1, F32, I32, I64, QUI8, QI8, TFL_Quint8, TFL_Str]>:$lhs, + TFL_TensorOf<[I1, F32, I32, I64, QUI8, QI8, TFL_Quint8, TFL_Str]>:$rhs); let results = (outs TFL_BoolTensor:$output); @@ -1250,7 +1319,7 @@ def TFL_EmbeddingLookupOp: TFL_Op<"embedding_lookup", PredOpTrait<"value and output must have same element type", TFL_TCresVTEtIsSameAsOp<0, 1>>, TFL_OperandHasRank<0, 1>, - TFL_OperandHasRankGreaterThanOrEqualTo<1, 2> + TFL_OperandHasRankAtLeast<1, 2> ]> { let summary = "Embedding lookup operator"; @@ -1468,7 +1537,11 @@ def TFL_FloorModOp : TFL_Op<"floor_mod", [ } def TFL_GreaterOp : TFL_Op<"greater", [ - ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> { + ResultsBroadcastableShape, + BinaryOpSameElementTypeConstraint, + TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>, + NoSideEffect, + NoQuantizableResult]> { let summary = "Greater operator"; let description = [{ @@ -1476,10 +1549,10 @@ def TFL_GreaterOp : TFL_Op<"greater", [ }]; let arguments = ( - ins AnyTensor:$lhs, - AnyTensor:$rhs); + ins TFL_TensorOf<[F32, I32, I64, QUI8, QI8, TFL_Quint8]>:$lhs, + TFL_TensorOf<[F32, I32, I64, QUI8, QI8, TFL_Quint8]>:$rhs); - let results = (outs AnyTensor:$output); + let results = (outs TFL_BoolTensor:$output); let builders = [TFL_ComparisonBinaryBuilder]; @@ -1488,9 +1561,12 @@ def TFL_GreaterOp : TFL_Op<"greater", [ let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }]; } -def TFL_HardSwishOp: TFL_Op<"hard_swish", [NoSideEffect, - SameOperandsAndResultShape, - TFL_GpuTargetOp]> { +def 
TFL_HardSwishOp: TFL_Op<"hard_swish", [ + NoSideEffect, + SameOperandsAndResultShape, + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + TFL_GpuTargetOp]> { let summary = "Hardswish activation function."; let description = [{ Computes hard-swish activation function @@ -1500,7 +1576,7 @@ def TFL_HardSwishOp: TFL_Op<"hard_swish", [NoSideEffect, let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8]>:$input); - let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$out); + let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$output); let hasOptions = 0; } @@ -1529,29 +1605,35 @@ def TFL_L2NormalizationOp : TFL_Op<"l2_normalization", [NoSideEffect, let customOption = "L2NormOptions"; } -def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [NoSideEffect, SameOperandsAndResultType]> { +def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [ + SameOperandsAndResultShape, + NoSideEffect, + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>]> { let summary = "Leaky Relu operator"; - // TODO(jpienaar): Add type restriction. This op is only defined for - // restricted (floating point) types. let description = [{ Element-wise Leaky ReLU operator x -> x >= 0 ? x : (alpha * x) }]; let arguments = ( - ins AnyTensor:$input, + ins TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$input, // Slope of the activation function at x < 0. F32Attr:$alpha ); - let results = (outs AnyTensor:$output); + let results = (outs TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$output); let hasOptions = 0b1; } def TFL_LessOp : TFL_Op<"less", [ - ResultsBroadcastableShape, NoSideEffect, NoQuantizableResult]> { + ResultsBroadcastableShape, + BinaryOpSameElementTypeConstraint, + TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<0, 1, 4>, + NoSideEffect, + NoQuantizableResult]> { let summary = "Less operator"; let description = [{ @@ -1559,8 +1641,8 @@ def TFL_LessOp : TFL_Op<"less", [ }]; let arguments = ( - ins AnyTensor:$lhs, - AnyTensor:$rhs); + ins TFL_TensorOf<[F32, I32, I64, QUI8, QI8, TFL_Quint8]>:$lhs, + TFL_TensorOf<[F32, I32, I64, QUI8, QI8, TFL_Quint8]>:$rhs); let results = (outs TFL_BoolTensor:$output); @@ -1621,6 +1703,8 @@ def TFL_LogicalOrOp : TFL_Op<"logical_or", [NoSideEffect]> { def TFL_LogisticOp: TFL_Op<"logistic", [ NoSideEffect, + PredOpTrait<"x and y must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, SameOperandsAndResultShape, // zero_point = 0 // scale = 1. / (max_value + 1) @@ -1633,9 +1717,9 @@ def TFL_LogisticOp: TFL_Op<"logistic", [ Computes element-wise Sigmoid of input }]; - let arguments = (ins TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$x); + let arguments = (ins TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$x); - let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$y); + let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$y); } def TFL_LogOp: TFL_Op<"log", [ @@ -1656,10 +1740,11 @@ def TFL_LogOp: TFL_Op<"log", [ let hasFolder = 1; } -// TODO(b/130643170): Adds some constraint for the input/output element types. 
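The FixedResultScale constraints on ops such as tfl.logistic pin the output to the fixed quantization parameters spelled out in the comment above (zero_point = 0 and scale = 1 / (max_value + 1) for unsigned 8-bit storage). The standalone sketch below reproduces that arithmetic; the helper name and the signed-storage variant (same scale, zero point shifted to the storage minimum) are illustrative assumptions, not TensorFlow APIs.

#include <cstdio>

// Fixed sigmoid output quantization, assuming 8-bit storage. Outputs lie in
// [0, 1), so spreading that range over 2^8 steps gives scale = 1/256, and the
// real value 0.0 is placed at the low end of the storage range.
struct QuantParams {
  double scale;
  int zero_point;
};

QuantParams FixedSigmoidOutputParams(bool signed_storage) {
  QuantParams p;
  p.scale = 1.0 / 256.0;                     // 1 / (max_value + 1), max_value = 255
  p.zero_point = signed_storage ? -128 : 0;  // real 0.0 maps to the storage minimum
  return p;
}

int main() {
  QuantParams u = FixedSigmoidOutputParams(false);
  QuantParams s = FixedSigmoidOutputParams(true);
  std::printf("uint8: scale=%g zp=%d\n", u.scale, u.zero_point);
  std::printf("int8:  scale=%g zp=%d\n", s.scale, s.zero_point);
  return 0;
}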
def TFL_LogSoftmaxOp : TFL_Op<"log_softmax", [ NoSideEffect, SameOperandsAndResultShape, + PredOpTrait<"x and y must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, // zero_point = max_value // scale = -log_softmax_output_min / (max_value + 1) FixedResultScale>, @@ -1672,9 +1757,9 @@ def TFL_LogSoftmaxOp : TFL_Op<"log_softmax", [ input - log(reduce_sum(exp(input), dim)) }]; - let arguments = (ins AnyTensor:$input); + let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$input); - let results = (outs AnyTensor:$output); + let results = (outs TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$output); let hasOptions = 1; } @@ -1693,6 +1778,9 @@ def MaxPoolOperandAndResultConstraints : PredOpTrait<"MaxPool2D operand and " TFL_TCresVTEtIsSameAsOp<0, 0>]>>; def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [ + TFL_OperandHasRank<0, 4>, + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, NoSideEffect, MaxPoolOperandAndResultConstraints, SameOperandsAndResultsScale, @@ -1707,7 +1795,7 @@ def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [ }]; let arguments = ( - ins AnyTensor:$input, + ins TFL_TensorOf<[F32, QUI8, QI8, QI16, TFL_Quint8]>:$input, TFL_PaddingAttr:$padding, I32Attr:$stride_w, I32Attr:$stride_h, @@ -1716,7 +1804,7 @@ def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [ TFL_AFAttr:$fused_activation_function ); - let results = (outs AnyTensor:$output); + let results = (outs TFL_TensorOf<[F32, QUI8, QI8, QI16, TFL_Quint8]>:$output); let hasOptions = 1; @@ -1748,7 +1836,11 @@ def TFL_MaximumOp : TFL_Op<"maximum", [ let hasOptions = 0; } -def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect, TFL_GpuTargetOp]> { +def TFL_MeanOp : TFL_Op<"mean", [ + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoSideEffect, + TFL_GpuTargetOp]> { let summary = "Mean operator"; let description = [{ @@ -1760,13 +1852,13 @@ def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect, TFL_GpuTargetOp]> { }]; let arguments = (ins - TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8, TFL_Uint8]>:$input, + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Uint8]>:$input, TFL_TensorOf<[I32, I64]>:$axis, BoolAttr:$keep_dims ); let results = (outs - TFL_TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$output); + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Uint8]>:$output); let hasOptions = 1; let customOption = "ReducerOptions"; @@ -1787,14 +1879,14 @@ def TFL_OneHotOp : TFL_Op<"one_hot", [NoSideEffect]> { let arguments = (ins TFL_TensorOf<[I32, I64]>:$indices, TFL_I32Tensor:$depth, - TFL_TensorOf<[F32, I32, I64, I1]>:$on_value, - TFL_TensorOf<[F32, I32, I64, I1]>:$off_value, + TFL_TensorOf<[F32, I32, I64, I1, I8, UI8]>:$on_value, + TFL_TensorOf<[F32, I32, I64, I1, I8, UI8]>:$off_value, I32Attr:$axis ); let results = (outs - TFL_TensorOf<[F32, I32, I64, I1]>:$output + TFL_TensorOf<[F32, I32, I64, I1, I8, UI8]>:$output ); let hasOptions = 1; @@ -1808,11 +1900,11 @@ Rounds the values of a tensor to the nearest integer, element-wise. 
}]; let arguments = (ins - TFL_TensorOf<[F32]>:$x + TFL_FpTensor:$x ); let results = (outs - TFL_TensorOf<[F32]>:$y + TFL_FpTensor:$y ); } @@ -1998,7 +2090,11 @@ def TFL_NegOp: TFL_Op<"neg", [NoSideEffect, SameOperandsAndResultType]> { let hasFolder = 1; } -def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> { +def TFL_PackOp : TFL_Op<"pack", [ + PredOpTrait<"values and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoSideEffect, + SameOperandsAndResultsScale]> { let summary = "Packs a list of tensors along a dimension into one tensor"; let description = [{ @@ -2029,14 +2125,14 @@ def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> { }]; let arguments = (ins - TFL_VariadicTensorOf<[F32, I8, I16, I32, I64, QI8, QUI8, QI16]>:$values, + TFL_VariadicTensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QUI8, QI16, TFL_Quint8]>:$values, - I32Attr:$values_count, + Confined:$values_count, I32Attr:$axis ); let results = (outs - TFL_TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8, QI16]>:$output + TFL_TensorOf<[F32, I8, I16, I32, I64, UI8, QI8, QUI8, QI16, TFL_Quint8]>:$output ); let verifier = [{ return Verify(*this); }]; @@ -2047,8 +2143,11 @@ def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> { } def TFL_PadOp : TFL_Op<"pad", [ + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, NoSideEffect, SameOperandsAndResultsScale, + TFL_OperandHasRankAtMost<0, 4>, TFL_OperandHasRank<1, 2>, TFL_OperandRankEquals1DimOfOperand<0, 1>, TFL_GpuTargetOp]> { @@ -2079,22 +2178,25 @@ def TFL_PadOp : TFL_Op<"pad", [ ``` }]; - let arguments = (ins TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input, + let arguments = (ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input, TFL_I32OrI64Tensor:$padding); - let results = (outs TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$output); + let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output); let hasOptions = 1; } def TFL_PadV2Op : TFL_Op<"padv2", [ + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, NoSideEffect, SameOperandsAndResultsScale, + TFL_OperandHasRankAtMost<0, 4>, TFL_OperandHasRank<1, 2>, TFL_OperandHasRank<2, 0>, TFL_OperandRankEquals1DimOfOperand<0, 1>, PredOpTrait<"input and constant value operands must have same element type", - TCopVTEtAreSameAt<[0, 2]>>]> { + TFL_TCopVTEtAreSameAt<0, 2>>]> { let summary = "Padding operator v2"; let description = [{ @@ -2125,11 +2227,11 @@ def TFL_PadV2Op : TFL_Op<"padv2", [ }]; let arguments = ( - ins TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input, + ins TFL_TensorOf<[F32, I32, I64, UI8, QI8, QUI8, TFL_Quint8]>:$input, TFL_I32OrI64Tensor:$padding, - TFL_TensorOf<[F32, I8, I32, I64]>:$constant_values); + TFL_TensorOf<[F32, I32, I64, UI8, QI8, QUI8, TFL_Quint8]>:$constant_values); - let results = (outs TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$output); + let results = (outs TFL_TensorOf<[F32, I32, I64, UI8, QI8, QUI8, TFL_Quint8]>:$output); let hasOptions = 1; } @@ -2157,9 +2259,21 @@ def TFL_PowOp : TFL_Op<"pow", [ResultsBroadcastableShape, let builders = [TFL_BroadcastableBinaryBuilder]; } -def TFL_PReluOp : TFL_Op<"prelu", [NoSideEffect, - TFL_GpuTargetOp, - SameOperandsAndResultsScale]> { +def TFL_PReluOp : TFL_Op<"prelu", [ + NoSideEffect, + ResultsBroadcastableShape, + TFL_GpuTargetOp, + TFL_OperandHasRankAtMost<0, 4>, + TFL_OperandHasRankAtMost<1, 4>, + BinaryOpSameElementTypeConstraint, + 
PredOpTrait<"input and output must have the same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + PredOpTrait<"'alpha' should have one less rank than 'input'.", + Or<[TFL_OperandIsUnrankedPred<0>, + TFL_OperandIsUnrankedPred<1>, + CPred<"$_op.getOperand(0).getType().cast().getRank() == " + "$_op.getOperand(1).getType().cast().getRank() " + "+ 1">]>>]> { let summary = "Parameterized Relu operator"; let description = [{ @@ -2172,11 +2286,11 @@ def TFL_PReluOp : TFL_Op<"prelu", [NoSideEffect, }]; let arguments = ( - ins TFL_TensorOf<[F32, QUI8]>:$input, - TFL_TensorOf<[F32, QUI8]>:$alpha + ins TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$input, + TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$alpha ); - let results = (outs TFL_TensorOf<[F32, QUI8]>:$output); + let results = (outs TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$output); let verifier = [{ return Verify(*this); }]; } @@ -2333,9 +2447,9 @@ def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect, Computes element-wise reverse square root of input }]; - let arguments = (ins AnyTensor:$x); + let arguments = (ins TFL_FpTensor:$x); - let results = (outs AnyTensor:$y); + let results = (outs TFL_FpTensor:$y); let hasFolder = 1; } @@ -2853,7 +2967,7 @@ def TFL_DepthToSpaceOp: TFL_Op<"depth_to_space", [ SameOperandsAndResultsScale, PredOpTrait<"input and output must have same element type", TFL_TCresVTEtIsSameAsOp<0, 0>>, - TFL_OperandHasRankLessThanOrEqualTo<0, 4> + TFL_OperandHasRankAtMost<0, 4> ]> { let summary = "DepthToSpace operator"; @@ -2965,7 +3079,8 @@ def TFL_ResizeNearestNeighborOp : TFL_Op<"resize_nearest_neighbor", let arguments = (ins TFL_TensorOf<[F32, I8, TFL_Uint8, QUI8, QI8]>:$input, TFL_TensorOf<[I32]>:$size, - BoolAttr:$align_corners + BoolAttr:$align_corners, + DefaultValuedAttr:$half_pixel_centers ); let results = (outs @@ -3189,7 +3304,7 @@ def TFL_QConstOp : Op:$output); let builders = [OpBuilder< "OpBuilder &, OperationState &state, TypeAttr qtype, Attribute value", @@ -3250,9 +3365,11 @@ def TFL_QuantizeOp: TFL_Op<"quantize", [ let results = (outs AnyTensor:$output); } -def TFL_DensifyOp: TFL_Op<"densify", [NoSideEffect, - SameOperandsAndResultType, - NoQuantizableResult]> { +def TFL_DensifyOp: TFL_Op<"densify", [ + NoSideEffect, + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoQuantizableResult]> { let summary = "Densify operator"; let description = [{ @@ -3814,7 +3931,7 @@ def TFL_NumericVerifyOp : Op:$input, + TFL_TensorOf<[QI8, QUI8, QI16, F16, TFL_Quint8]>:$input, TFL_TensorOf<[F32]>:$ref, // Attributes diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index c338b723a4a..51fcbb97360 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -146,6 +146,10 @@ Status ConvertSavedModelToTFLiteFlatBuffer( saved_model_exported_names.begin(), saved_model_exported_names.end()); absl::Span exported_names(exported_names_in_vector); + if (exported_names.size() != 1) { + return errors::Unimplemented("Only support a single exported name."); + } + TF_ASSIGN_OR_RETURN(auto module, ImportSavedModel(model_flags.saved_model_dir(), model_flags.saved_model_version(), tags, diff --git a/tensorflow/compiler/mlir/lite/quantization/device_target.cc b/tensorflow/compiler/mlir/lite/quantization/device_target.cc index 48c0345ff3d..6b5c894b7f5 100644 --- 
a/tensorflow/compiler/mlir/lite/quantization/device_target.cc +++ b/tensorflow/compiler/mlir/lite/quantization/device_target.cc @@ -32,6 +32,7 @@ namespace mlir { namespace quant { constexpr int k8Bits = 8; +constexpr int k32Bits = 32; constexpr unsigned kSigned = quant::QuantizationFlags::Signed; DeviceTarget::DeviceTarget(MLIRContext* ctx) : ctx_(ctx) { @@ -39,20 +40,20 @@ DeviceTarget::DeviceTarget(MLIRContext* ctx) : ctx_(ctx) { i8_ = IntegerType::get(k8Bits, ctx_); i8_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k8Bits); i8_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k8Bits); + i32_ = IntegerType::get(k32Bits, ctx_); + i32_min_ = QuantizedType::getDefaultMinimumForInteger(kSigned, k32Bits); + i32_max_ = QuantizedType::getDefaultMaximumForInteger(kSigned, k32Bits); any_ = AnyQuantizedType(); qi8_ = AnyQuantizedType::get(kSigned, i8_, f32_, i8_min_, i8_max_); qi8n_ = AnyQuantizedType::get(kSigned, i8_, f32_, i8_min_ + 1, i8_max_); + qi32_ = AnyQuantizedType::get(kSigned, i32_, f32_, i32_min_, i32_max_); assert(qi8n_ == qi8n_); } -Optional DeviceTarget::GetKernelSpec(QuantizeRegionOp op) const { - auto kernel_specs_it = specs_.find(op.logical_kernel()); +Optional DeviceTarget::GetKernelSpec( + llvm::StringRef kernel, const KernelSpecs::Signature& signature) const { + auto kernel_specs_it = specs_.find(kernel); if (kernel_specs_it == specs_.end()) return llvm::None; - - KernelSpecs::Signature signature; - signature.reserve(op.input_specs().size() + op.output_specs().size()); - AppendToSignature(op.input_specs(), &signature); - AppendToSignature(op.output_specs(), &signature); return kernel_specs_it->getValue().Find(signature); } @@ -62,31 +63,38 @@ ScaleDecomposeFn DeviceTarget::GetDecomposeFn(QuantizeRegionOp op) const { return kernel_specs_it->second.GetDecomposeFn(); } +void DeviceTarget::AppendToSignature(Type spec, + KernelSpecs::Signature* signature) { + if (auto quant = spec.dyn_cast_or_null()) { + signature->push_back(AnyQuantizedType::get( + quant.getFlags(), quant.getStorageType(), quant.getExpressedType(), + quant.getStorageTypeMin(), quant.getStorageTypeMax())); + } else if (auto any = spec.dyn_cast_or_null()) { + signature->push_back(any); + } else { // float + signature->push_back(AnyQuantizedType()); + } +} + LogicalResult DeviceTarget::RegisterKernel( llvm::StringRef kernel, const KernelSpecs::Signature& signature, const ScaleFn& fn, const ScaleDecomposeFn& dfn) { return specs_[kernel].Add(signature, {ScaleConstraintType::CustomScale, fn}); } +namespace ph = std::placeholders; + LogicalResult DeviceTarget::RegisterKernel( llvm::StringRef kernel, const KernelSpecs::Signature& signature, const ScaleConstraintType constraint) { - return specs_[kernel].Add(signature, {constraint, {}}); -} - -void DeviceTarget::AppendToSignature(ArrayAttr specs_attr, - KernelSpecs::Signature* signature) const { - for (auto attr : specs_attr) { - Type spec = attr.cast().getValue(); - if (auto quant = spec.dyn_cast()) { - signature->push_back(AnyQuantizedType::get( - quant.getFlags(), quant.getStorageType(), quant.getExpressedType(), - quant.getStorageTypeMin(), quant.getStorageTypeMax())); - } else if (auto any = spec.dyn_cast()) { - signature->push_back(any); - } else { // float - signature->push_back({}); - } + if (failed(specs_[kernel].Add(signature, {constraint, {}}))) return failure(); + switch (constraint) { + case ScaleConstraintType::OutputInputSameScale: + specs_[kernel].WithImpl(std::bind(&DeviceTarget::DecomposeSameScale, + ph::_1, ph::_2, ph::_3, 
ph::_4)); + return success(); + default: + return failure(); } } @@ -119,7 +127,7 @@ LogicalResult DeviceTarget::DecomposeMultiplyAccumulateScale( input_multipliers->append(3, kUnitQuantizedMultiplier); // output multipliers - double real_multiplier = o_spec.getScale() / scale_product; + double real_multiplier = scale_product / o_spec.getScale(); output_multipliers->push_back(quant::QuantizeMultiplier(real_multiplier)); // output ranges @@ -134,5 +142,40 @@ LogicalResult DeviceTarget::DecomposeMultiplyAccumulateScale( return success(); } +LogicalResult DeviceTarget::DecomposeSameScale( + Operation* op, quant::QuantizedMultipliers* input_multipliers, + quant::QuantizedMultipliers* output_multipliers, + quant::QuantizedRanges* output_ranges) { + auto rop = llvm::dyn_cast(op); + if (!rop) return failure(); + + // input multipliers + for (int i = 0; i < op->getNumOperands(); ++i) { + input_multipliers->push_back(kUnitQuantizedMultiplier); + } + + // output multipliers + for (int i = 0; i < op->getNumResults(); ++i) { + output_multipliers->push_back(kUnitQuantizedMultiplier); + } + + auto o_spec = rop.output_specs()[0] + .cast() + .getValue() + .dyn_cast(); + if (!o_spec) return failure(); + + // output ranges + auto min = rop.getAttrOfType("min"); + auto max = rop.getAttrOfType("max"); + output_ranges->push_back(quant::CalculateQuantizedRange( + o_spec.getScale(), o_spec.getZeroPoint(), + (min ? absl::optional(min.getValueAsDouble()) : absl::nullopt), + (max ? absl::optional(max.getValueAsDouble()) : absl::nullopt), + o_spec.getStorageTypeMin(), o_spec.getStorageTypeMax())); + + return success(); +} + } // namespace quant } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/device_target.h b/tensorflow/compiler/mlir/lite/quantization/device_target.h index 65e0c5fe4a6..8ed43157df8 100644 --- a/tensorflow/compiler/mlir/lite/quantization/device_target.h +++ b/tensorflow/compiler/mlir/lite/quantization/device_target.h @@ -134,11 +134,18 @@ class DeviceTarget { explicit DeviceTarget(MLIRContext* ctx); // Retrieves the kernel spec for the quant region op. - Optional GetKernelSpec(quant::QuantizeRegionOp op) const; + Optional GetKernelSpec( + llvm::StringRef kernel, const KernelSpecs::Signature& signature) const; // Retrieves the scale decomposition function for the quant region op. ScaleDecomposeFn GetDecomposeFn(quant::QuantizeRegionOp op) const; + // converts specification to signature: + // - UniformedQuantizedType -> AnyQuantizedType + // - AnyQuantizedType (int) -> AnyQuantizedType + // - Float -> {} + static void AppendToSignature(Type spec, KernelSpecs::Signature* signature); + protected: // Adds the kernel spec with the custom scale function for the kernel. LogicalResult RegisterKernel(llvm::StringRef kernel, @@ -154,13 +161,6 @@ class DeviceTarget { // added before. KernelSpecs& RegisterKernel(llvm::StringRef kernel) { return specs_[kernel]; } - // converts specification to signature: - // - UniformedQuantizedType -> AnyQuantizedType - // - AnyQuantizedType (int) -> AnyQuantizedType - // - Float -> {} - void AppendToSignature(ArrayAttr specs_attr, - KernelSpecs::Signature* signature) const; - // For "mulmat->add" type of kernels, convert the scales of all the ports to // multipliers. static LogicalResult DecomposeMultiplyAccumulateScale( @@ -168,11 +168,17 @@ class DeviceTarget { quant::QuantizedMultipliers* output_multipliers, quant::QuantizedRanges* output_ranges); + // For "reshape" type of kernels. 
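The decompose functions above reduce each port to either a unit multiplier or a real multiplier derived from the scales (the corrected real_multiplier is the product of the input scales divided by the output scale), which quant::QuantizeMultiplier then turns into an integer multiplier plus a shift. The sketch below shows the usual frexp-based decomposition behind that step; the exact rounding and sign handling in the real helper may differ, so treat it as an illustration only.

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <utility>

// Splits a positive real multiplier into a Q31 fixed-point multiplier and a
// power-of-two exponent, so real ~= (q_fixed / 2^31) * 2^shift.
std::pair<int32_t, int> DecomposeMultiplier(double real_multiplier) {
  if (real_multiplier == 0.0) return {0, 0};
  int shift = 0;
  const double q = std::frexp(real_multiplier, &shift);  // q in [0.5, 1)
  int64_t q_fixed = static_cast<int64_t>(std::llround(q * (1LL << 31)));
  if (q_fixed == (1LL << 31)) {  // rounding pushed q up to 1.0
    q_fixed /= 2;
    ++shift;
  }
  return {static_cast<int32_t>(q_fixed), shift};
}

int main() {
  // e.g. input_scale * filter_scale / output_scale for a multiply-accumulate kernel.
  const double real_multiplier = (0.02 * 0.01) / 0.05;
  const auto decomposed = DecomposeMultiplier(real_multiplier);
  std::printf("multiplier=%d shift=%d\n", static_cast<int>(decomposed.first),
              decomposed.second);
  return 0;
}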
+ static LogicalResult DecomposeSameScale( + Operation* op, quant::QuantizedMultipliers* input_multipliers, + quant::QuantizedMultipliers* output_multipliers, + quant::QuantizedRanges* output_ranges); + // A set of parameters are required to build the signatures. FloatType f32_; - IntegerType i8_; - int64_t i8_min_, i8_max_; - AnyQuantizedType any_, qi8_, qi8n_; + IntegerType i8_, i32_; + int64_t i8_min_, i8_max_, i32_min_, i32_max_; + AnyQuantizedType any_, qi8_, qi8n_, qi32_; private: // Maps the kernel names to all the available kernels. diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD index 1504f7d3a1b..b4fddceb580 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD @@ -72,5 +72,6 @@ tf_cc_binary( "//tensorflow/lite/schema:schema_fbs", "@com_google_absl//absl/strings", "@llvm-project//llvm:support", + "@llvm-project//mlir:AllPassesAndDialects", ], ) diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc index 9b49757fd3f..a2e3c065113 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace mlir { namespace lite { @@ -38,7 +39,9 @@ namespace lite { TfLiteStatus QuantizeModel( const tflite::ModelT& input_model, const tflite::TensorType& input_type, const tflite::TensorType& output_type, - const std::unordered_set& operator_names, bool fully_quantize, + const tflite::TensorType& inference_type, + const std::unordered_set& operator_names, + bool disable_per_channel, bool fully_quantize, flatbuffers::FlatBufferBuilder* builder, tflite::ErrorReporter* error_reporter) { // TODO(b/142502494): remove this restriction by improving the `emit_adaptor` @@ -72,15 +75,18 @@ TfLiteStatus QuantizeModel( // Apply quantization passes PassManager pm(module->getContext()); TFL::QuantizationSpecs quant_specs; - quant_specs.inference_type = tensorflow::DT_QINT8; + quant_specs.inference_type = tflite::TflTypeToTfType(inference_type); quant_specs.post_training_quantization = true; + quant_specs.disable_per_channel = disable_per_channel; bool emit_adaptor = false; auto input_tf_type = tflite::TflTypeToTfType(input_type); if (input_tf_type == tensorflow::DT_FLOAT) { emit_adaptor = true; - } else if (input_tf_type == tensorflow::DT_UINT8) { - quant_specs.inference_type = tensorflow::DT_QUINT8; + } else if (input_tf_type == tensorflow::DT_UINT8 || + input_tf_type == tensorflow::DT_INT8 || + input_tf_type == tensorflow::DT_INT16) { + quant_specs.inference_type = input_tf_type; } pm.addPass(TFL::CreatePrepareQuantizePass(quant_specs)); diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h index 473e97e07df..d60df56b473 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h @@ -26,12 +26,15 @@ namespace mlir { namespace lite { // Quantize the `input_model` and write the result to a flatbuffer `builder`. 
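The new i32_/qi32_ signature members above are built from the same default integer bounds as the existing 8-bit ones, via QuantizedType::getDefaultMinimumForInteger / getDefaultMaximumForInteger. As a quick reference, the snippet below reproduces the signed two's-complement bounds those defaults correspond to; it is plain arithmetic for illustration, not the MLIR helper.

#include <cstdint>
#include <cstdio>

// Two's-complement bounds for a signed storage type of `bits` width, matching
// the values used to build the qi8/qi32 signature types.
constexpr int64_t SignedMin(unsigned bits) { return -(int64_t{1} << (bits - 1)); }
constexpr int64_t SignedMax(unsigned bits) { return (int64_t{1} << (bits - 1)) - 1; }

int main() {
  std::printf("8-bit:  [%lld, %lld]\n", static_cast<long long>(SignedMin(8)),
              static_cast<long long>(SignedMax(8)));
  std::printf("32-bit: [%lld, %lld]\n", static_cast<long long>(SignedMin(32)),
              static_cast<long long>(SignedMax(32)));
  return 0;
}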
-// The `input_type` and `output_type` can be float32/qint8/int8. +// The `input_type`, `output_type` and `inference_type` can be +// float32/qint8/int8/int16. // Return partially quantized model if `fully_quantize` is false. TfLiteStatus QuantizeModel( const tflite::ModelT& input_model, const tflite::TensorType& input_type, const tflite::TensorType& output_type, - const std::unordered_set& operator_names, bool fully_quantize, + const tflite::TensorType& inference_type, + const std::unordered_set& operator_names, + bool disable_per_channel, bool fully_quantize, flatbuffers::FlatBufferBuilder* builder, tflite::ErrorReporter* error_reporter); } // namespace lite diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc index 7530cdf008f..5bd1b71e631 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_quantizer.cc @@ -46,7 +46,9 @@ TfLiteStatus QuantizeAnnotatedModel(llvm::StringRef buffer, tflite::StderrReporter error_reporter; return mlir::lite::QuantizeModel( - *model, tflite::TensorType_INT8, tflite::TensorType_INT8, {}, + *model, tflite::TensorType_INT8, tflite::TensorType_INT8, + tflite::TensorType_INT8, {}, + /*disable_per_channel=*/false, /*fully_quantize=*/true, builder, &error_reporter); } diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_config.h b/tensorflow/compiler/mlir/lite/quantization/quantization_config.h index 5b1c73e7887..2ffba579548 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_config.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_config.h @@ -46,6 +46,12 @@ struct QuantizationSpecs { // post-training quantization. We need to deprecate the `weight_quantization`. bool post_training_quantization = false; + // When set to true, quantization will be done per-tensor. Currently, this + // option is only valid when the quantization parameters need to be created by + // scanning the constant content (post-training quantization or QAT without + // weight FakeQuant). + bool disable_per_channel = false; + // The node type when the model is exported. Currently this is limited to // DT_FLOAT, DT_HALF, DT_QINT8, and DT_QUINT8. When DT_HALF is used, the // `weight_quantization` flag needs to set to true. When DT_QUINT8 is used, @@ -84,7 +90,7 @@ struct QuantizationSpecs { bool RunWeightQuantization() const { return weight_quantization; } // Whether this inference type represents a signed storage type. - bool IsSignedInferenceType() { + bool IsSignedInferenceType() const { switch (inference_type) { case tensorflow::DT_QUINT8: case tensorflow::DT_QUINT16: @@ -96,7 +102,7 @@ struct QuantizationSpecs { // Gets the width of this quantization type. Returns 0 if it isn't a // quantization type. 
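The flags being threaded through here (inference type, signedness, per-channel) ultimately just select the affine mapping between real values and stored integers. For readers following the quantization changes, a self-contained reminder of that mapping follows; the function names are illustrative, not part of the TensorFlow code.

#include <algorithm>
#include <cmath>
#include <cstdio>

// Affine quantization: q = clamp(round(x / scale) + zero_point, qmin, qmax);
// dequantization is the inverse, x' = scale * (q - zero_point).
int Quantize(float x, float scale, int zero_point, int qmin, int qmax) {
  const int q = static_cast<int>(std::lround(x / scale)) + zero_point;
  return std::min(std::max(q, qmin), qmax);
}

float Dequantize(int q, float scale, int zero_point) {
  return scale * static_cast<float>(q - zero_point);
}

int main() {
  // Signed 8-bit storage (DT_QINT8-style): range [-128, 127].
  const float scale = 0.05f;
  const int zero_point = 0;
  const int q = Quantize(1.37f, scale, zero_point, -128, 127);
  std::printf("q=%d dequantized=%f\n", q, Dequantize(q, scale, zero_point));
  return 0;
}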
- int64_t GetQuantizationTypeWidth() { + int64_t GetQuantizationTypeWidth() const { switch (inference_type) { case tensorflow::DT_QINT8: case tensorflow::DT_QUINT8: diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc index 2b2c44f03a4..bcfd06cf06c 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc @@ -64,10 +64,23 @@ std::vector QuantizeContext::GetAllOps() { return all_ops; } +KernelSpecs::Signature QuantizeContext::GetSignature(QuantizeRegionOp op) { + KernelSpecs::Signature signature; + signature.reserve(op.input_specs().size() + op.output_specs().size()); + for (int i = 0; i < op.getNumOperands(); ++i) { + DeviceTarget::AppendToSignature(GetOperandParams(op, i), &signature); + } + for (int i = 0; i < op.getNumResults(); ++i) { + DeviceTarget::AppendToSignature(GetResultParams(op, i), &signature); + } + return signature; +} + LogicalResult QuantizeContext::Handle( quant::QuantizeRegionOp op, llvm::SmallVectorImpl *new_items, bool *changed) { - auto spec = target_spec_.GetKernelSpec(op); + auto signature = GetSignature(op); + auto spec = target_spec_.GetKernelSpec(op.logical_kernel(), signature); if (!spec.hasValue()) { op.emitWarning( "Couldn't find kernel from the registeration for quantization."); diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_context.h b/tensorflow/compiler/mlir/lite/quantization/quantization_context.h index 0d460fd9a50..0c5137eb1a2 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_context.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_context.h @@ -107,6 +107,9 @@ class QuantizeContext { return states_manager_.GetOperandParams(op, index); } + // Return the signature of the op. + KernelSpecs::Signature GetSignature(QuantizeRegionOp op); + // A heuristic to get quantization parameters satisfies the same scale // constraints: // - If there are immutable states, diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index 25ee1d8ba5d..15b6bf56b7a 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -1213,15 +1213,14 @@ func @resize_nearest_neighbor(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi3 %0 = "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor return %0 : tensor // CHECK-LABEL: resize_nearest_neighbor - // CHECK: "tfl.resize_nearest_neighbor"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor + // CHECK: "tfl.resize_nearest_neighbor"(%arg0, %arg1) {align_corners = true, half_pixel_centers = false} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor } -// Note: half_pixel_centers isn't supported by TFLite, so it's not legalized. 
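With half_pixel_centers now carried through to tfl.resize_nearest_neighbor in the legalization tests here, the attribute only changes how an output index is mapped back to a source index. The sketch below mirrors the common TensorFlow formulation of that mapping; the exact rounding in the TFLite kernel may differ slightly, so it is illustrative only.

#include <algorithm>
#include <cmath>
#include <cstdio>

// Maps an output index to a source index for nearest-neighbor resizing.
// half_pixel_centers samples at pixel centers ((x + 0.5) * scale) instead of
// pixel corners (x * scale); align_corners changes both the scale and the
// rounding mode.
int SourceIndex(int x, int in_size, int out_size, bool align_corners,
                bool half_pixel_centers) {
  const float scale = (align_corners && out_size > 1)
                          ? (in_size - 1) / static_cast<float>(out_size - 1)
                          : in_size / static_cast<float>(out_size);
  const float in = half_pixel_centers ? (x + 0.5f) * scale : x * scale;
  const int idx = align_corners ? static_cast<int>(std::round(in))
                                : static_cast<int>(std::floor(in));
  return std::max(0, std::min(idx, in_size - 1));
}

int main() {
  for (int x = 0; x < 4; ++x) {
    std::printf("out %d -> src %d (half-pixel) vs %d (default)\n", x,
                SourceIndex(x, 8, 4, false, true),
                SourceIndex(x, 8, 4, false, false));
  }
  return 0;
}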
func @resize_nearest_neighbor_with_half_pixel_centers(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor { - %0 = "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = true, half_pixel_centers = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor + %0 = "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = false, half_pixel_centers = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor return %0 : tensor // CHECK-LABEL: resize_nearest_neighbor_with_half_pixel_centers - // CHECK: "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = true, half_pixel_centers = true} + // CHECK: "tfl.resize_nearest_neighbor"(%arg0, %arg1) {align_corners = false, half_pixel_centers = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor } func @sparse_to_dense_with_scalar_sparse_indices(%arg0: tensor, %arg1: tensor<3xi32>, %arg2: tensor, %arg3: tensor) -> tensor { @@ -1497,3 +1496,27 @@ func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3 // CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32> // CHECK: return [[MUL]] : tensor<3x3xi32> } + +func @matmul_batch(%arg0: tensor<10x15xf32>, %arg1: tensor<15x17xf32>) -> tensor<10x17xf32> { + %0 = "tf.BatchMatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", adj_x = false, adj_y = false} : +(tensor<10x15xf32>, tensor<15x17xf32>) -> tensor<10x17xf32> + return %0 : tensor<10x17xf32> +// CHECK-LABEL: matmul_batch +// CHECK: "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<10x15xf32>, tensor<15x17xf32>) -> tensor<10x17xf32> +} + +func @matmul_batchv2(%arg0: tensor<2x10x15xf32>, %arg1: tensor<15x17xf32>) -> tensor<2x10x17xf32> { + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", adj_x = false, adj_y = false} : +(tensor<2x10x15xf32>, tensor<15x17xf32>) -> tensor<2x10x17xf32> + return %0 : tensor<2x10x17xf32> +// CHECK-LABEL: matmul_batchv2 +// CHECK: "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<2x10x15xf32>, tensor<15x17xf32>) -> tensor<2x10x17xf32> +} + +func @matmul_batchv2_unknown_dim(%arg0: tensor, %arg1: tensor<15x17xf32>) -> tensor { + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", adj_x = false, adj_y = false} : +(tensor, tensor<15x17xf32>) -> tensor + return %0 : tensor +// CHECK-LABEL: matmul_batchv2_unknown_dim +// CHECK: "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor, tensor<15x17xf32>) -> tensor +} diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index 38f736ee378..f42e06350e5 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -192,7 +192,7 @@ func @testSquare(tensor) -> tensor { func @testQuantizedResizeNearestNeighbor(tensor>, tensor) -> tensor> { ^bb0(%arg0: tensor>, %arg1: tensor): - %0 = "tfl.resize_nearest_neighbor"(%arg0, %arg1) { align_corners = false } : (tensor>, tensor) -> tensor> + %0 = "tfl.resize_nearest_neighbor"(%arg0, %arg1) { align_corners = false, half_pixel_centers = false } : (tensor>, tensor) -> tensor> return %0 : tensor> } @@ -573,7 +573,7 @@ func @testLogistic(tensor<1x2x3x4x5xf32>) -> tensor<1x2x3x4x5xf32> { // test invalid Logistic input func @testLogisticWithWrongInputType(tensor) -> tensor { ^bb0(%arg0: tensor): - // expected-error @+1 
{{tfl.logistic' op operand #0 must be tensor of 32-bit float or QI8 type or QUI8 type or QI16 type or QUI16 type values}} + // expected-error @+1 {{'tfl.logistic' op operand #0 must be tensor of 32-bit float or QI8 type or QUI8 type or QI16 type or TFLite quint8 type values, but got 'tensor'}} %0 = "tfl.logistic"(%arg0): (tensor) -> tensor return %0#0 : tensor } @@ -1252,10 +1252,10 @@ func @testOneHot(%arg0: tensor<3xi32>, %arg1: tensor, %arg2: tensor, % // ----- -func @testOneHotWithInvalidOutputType(%arg0: tensor<3xi32>, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> tensor<*xi8> { - // expected-error @+1 {{'tfl.one_hot' op result #0 must be tensor of 32-bit float or 32-bit signless integer or 64-bit signless integer or 1-bit signless integer values}} - %0 = "tfl.one_hot"(%arg0, %arg1, %arg2, %arg3) {axis = -1 : i32} : (tensor<3xi32>, tensor, tensor, tensor) -> tensor<*xi8> - return %0 : tensor<*xi8> +func @testOneHotWithInvalidOutputType(%arg0: tensor<3xi32>, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> tensor<*xi16> { + // expected-error @+1 {{'tfl.one_hot' op result #0 must be tensor of 32-bit float or 32-bit signless integer or 64-bit signless integer or 1-bit signless integer or 8-bit signless integer or 8-bit unsigned integer values, but got 'tensor<*xi16>'}} + %0 = "tfl.one_hot"(%arg0, %arg1, %arg2, %arg3) {axis = -1 : i32} : (tensor<3xi32>, tensor, tensor, tensor) -> tensor<*xi16> + return %0 : tensor<*xi16> } // ----- @@ -1489,7 +1489,8 @@ func @testEmbeddingLookupValueAndResultElementTypeTraitFailed(%arg0 : tensor>) -> tensor<1x56x56x192x!quant.uniform> { +func @testWrongQuantizedLocalResponseNormalization(%arg0 : tensor<1x56x56x192x!quant.uniform>) -> tensor<1x56x56x192x!quant.uniform> { + // expected-error @+1 {{'tfl.local_response_normalization' op operand #0 must be tensor of 32-bit float values, but got 'tensor<1x56x56x192x!quant.uniform>'}} %0 = "tfl.local_response_normalization"(%arg0) {alpha = 9.99999974E-5 : f32, beta = 5.000000e-01 : f32, bias = 2.000000e+00 : f32, radius = 5 : i32} : (tensor<1x56x56x192x!quant.uniform>) -> tensor<1x56x56x192x!quant.uniform> return %0 : tensor<1x56x56x192x!quant.uniform> } @@ -1523,32 +1524,32 @@ func @testDepthToSpaceInvalidOutputType(%arg0: tensor<1x1x1x4xf32>) -> tensor<1x // ----- -func @testPReluWrongOutputRank(%arg0: tensor<10x10x10x10xf32>, %arg1: tensor<1x1x10xf32>) -> tensor<10x10x10xf32> { - // expected-error @+1 {{'input' and 'output' should have the same rank}} - %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<10x10x10x10xf32>, tensor<1x1x10xf32>) -> tensor<10x10x10xf32> - return %0 : tensor<10x10x10xf32> +func @testPReluWrongOutputRank(%arg0: tensor<10x10x10x10xf32>, %arg1: tensor<10x10x10x10xf32>) -> tensor<10x10xf32> { + // expected-error @+1 {{'tfl.prelu' op result type '10x10' not broadcast compatible with broadcasted operands's shapes '10x10x10x10'}} + %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<10x10x10x10xf32>, tensor<10x10x10x10xf32>) -> tensor<10x10xf32> + return %0 : tensor<10x10xf32> } // ----- func @testPReluWrongOutputShape(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<1x2x3x5xf32> { - // expected-error @+1 {{'input' and 'output' should have the same shape}} + // expected-error @+1 {{'tfl.prelu' op result type '1x2x3x5' not broadcast compatible with broadcasted operands's shapes '1x2x3x4'}} %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<1x2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<1x2x3x5xf32> return %0 : tensor<1x2x3x5xf32> } // ----- -func @testPReluWrongAlphaRank(%arg0: tensor<7x3x2x14xf32>, %arg1: 
tensor<2x7x3x2x14xf32>) -> tensor<7x3x2x14xf32> { +func @testPReluWrongAlphaRank(%arg0: tensor<7x3x2x14xf32>, %arg1: tensor<7x3x2x14xf32>) -> tensor<7x3x2x14xf32> { // expected-error @+1 {{'alpha' should have one less rank than 'input'.}} - %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<7x3x2x14xf32>, tensor<2x7x3x2x14xf32>) -> tensor<7x3x2x14xf32> + %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<7x3x2x14xf32>, tensor<7x3x2x14xf32>) -> tensor<7x3x2x14xf32> return %0 : tensor<7x3x2x14xf32> } // ----- func @testPReluInvalidBroadcast(%arg0: tensor<15x14x2x14xf32>, %arg1: tensor<1x1x3xf32>) -> tensor<15x14x2x14xf32> { - // expected-error @+1 {{'alpha' is not broadcastable at dimension 2.}} + // expected-error @+1 {{'tfl.prelu' op operands don't have broadcast-compatible shapes}} %0 = "tfl.prelu"(%arg0, %arg1) : (tensor<15x14x2x14xf32>, tensor<1x1x3xf32>) -> tensor<15x14x2x14xf32> return %0 : tensor<15x14x2x14xf32> } diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index d1c0dd20c05..2815afd14b9 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -958,3 +958,16 @@ func @FusingdivRelu(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> // Fusing: %[[div2:[0-9].*]] = tfl.div %[[relu]], %[[div1]] {fused_activation_function = "RELU6"} : tensor<1xf32> // Fusing: return } + +func @ReorderAddWithConstant(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + %cst = constant dense<1.0> : tensor<2x2xf32> + %cst_1 = constant dense<2.0> : tensor<2x2xf32> + %0 = "tfl.add"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + %1 = "tfl.add"(%0, %cst_1) {fused_activation_function = "NONE"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %1 : tensor<2x2xf32> + + // CHECK-LABEL: ReorderAddWithConstant + // CHECK: %[[CONST:.*]] = constant dense<3.000000e+00> : tensor<2x2xf32> + // CHECK: %[[RESULT:.*]] = tfl.add %arg0, %[[CONST]] {fused_activation_function = "NONE"} : tensor<2x2xf32> +} + diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index 5eefa821c6b..d3f1a430642 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -48,7 +48,8 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs, quant_specs.default_ranges.second.hasValue()) { pass_manager->addPass(mlir::TFL::CreateDefaultQuantParamsPass( quant_specs.default_ranges.first.getValueOr(0.0), - quant_specs.default_ranges.second.getValueOr(0.0))); + quant_specs.default_ranges.second.getValueOr(0.0), + quant_specs.IsSignedInferenceType())); pass_manager->addPass(mlir::TFL::CreateQuantizePass()); pass_manager->addPass( mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops)); @@ -73,16 +74,17 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, pass_manager->addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass()); } + if (pass_config.shape_inference) { + pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass()); + } + // Keep this pass after the shape inference pass, which couldn't do shape + // inference for non-tf ops. 
if (!pass_config.quant_specs.serialized_quant_stats.empty()) { pass_manager->addPass( mlir::quant::CreateImportQuantStatsPassForTFControlDialect( pass_config.quant_specs.serialized_quant_stats)); } - if (pass_config.shape_inference) { - pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass()); - } - // The conversion pipeline has to follow the following orders: // 1) Saved model related optimization like decompose resource ops // 2) Convert composite functions like lstm/rnns, along with proper function diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index 4bc9d9e0c2d..fce1333a491 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -160,6 +160,11 @@ int main(int argc, char **argv) { absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty()); absl::Span exported_names(exported_names_vector); + if (exported_names.size() != 1) { + llvm::errs() << "There should be only one exported name"; + return kTrFailure; + } + module = tensorflow::ImportSavedModel(input_file_name, saved_model_version, tags, exported_names, &context); } else { diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index b9ec67736d9..62f64ab63b4 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -174,7 +174,7 @@ StatusOr ImportSavedModel( return module; } else if (saved_model_version == 1) { auto module = tensorflow::SavedModelSignatureDefsToMlirImport( - input_filename, tags, context); + input_filename, tags, exported_names, context); if (!module) return tensorflow::errors::InvalidArgument("fail to open input file"); diff --git a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc index a1602baced5..c23ae9fcfab 100644 --- a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc +++ b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc @@ -46,8 +46,11 @@ namespace { class DefaultQuantParamsPass : public PassWrapper { public: - explicit DefaultQuantParamsPass(double default_min, double default_max) - : default_min_(default_min), default_max_(default_max) {} + explicit DefaultQuantParamsPass(double default_min, double default_max, + bool is_signed) + : default_min_(default_min), + default_max_(default_max), + is_signed_(is_signed) {} void runOnFunction() override; @@ -82,6 +85,7 @@ class DefaultQuantParamsPass double default_min_; double default_max_; + bool is_signed_; quant::QuantParams default_quant_params_; }; } // namespace @@ -214,15 +218,16 @@ quant::QuantParams DefaultQuantParamsPass::GetDefaultQuantParams( default_quant_params_ = quant::fakeQuantAttrsToType( builder.getUnknownLoc(), /*numBits=*/8, default_min_, default_max_, /*narrowRange=*/false, - builder.getF32Type()); + builder.getF32Type(), is_signed_); } return default_quant_params_; } // Creates an instance of the default quant parameters pass. 
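The default quant params pass above turns a (default_min, default_max) range into a quantized type via quant::fakeQuantAttrsToType, now with the storage signedness threaded through. The sketch below shows one common way such a range is converted into a scale and a nudged zero point for 8-bit storage; it is a simplified illustration of that conversion, not the MLIR helper itself, and it ignores narrow-range handling.

#include <algorithm>
#include <cmath>
#include <cstdio>

struct QuantParams {
  float scale;
  int zero_point;
};

// Derives scale/zero_point from a real [rmin, rmax] range for 8-bit storage.
// Simplified: assumes rmin <= 0 <= rmax.
QuantParams ParamsFromMinMax(float rmin, float rmax, bool is_signed) {
  const int qmin = is_signed ? -128 : 0;
  const int qmax = is_signed ? 127 : 255;
  QuantParams p;
  p.scale = (rmax - rmin) / static_cast<float>(qmax - qmin);
  // Nudge the zero point to an integer inside [qmin, qmax] so that the real
  // value 0.0 is exactly representable.
  const float zp_real = qmin - rmin / p.scale;
  p.zero_point = std::min(
      qmax, std::max(qmin, static_cast<int>(std::lround(zp_real))));
  return p;
}

int main() {
  const QuantParams p = ParamsFromMinMax(-1.0f, 1.0f, /*is_signed=*/false);
  std::printf("scale=%f zero_point=%d\n", p.scale, p.zero_point);
  return 0;
}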
std::unique_ptr> CreateDefaultQuantParamsPass( - double default_min, double default_max) { - return absl::make_unique(default_min, default_max); + double default_min, double default_max, bool is_signed) { + return absl::make_unique(default_min, default_max, + is_signed); } // Registers this pass with default values, only for test @@ -230,7 +235,8 @@ static PassRegistration pass( "tfl-default-quant", "Apply quantization with default quantization parameter", [] { return CreateDefaultQuantParamsPass(/*default_min=*/-1.0, - /*default_max=*/1.0); + /*default_max=*/1.0, + /*is_signed=*/false); }); } // namespace TFL diff --git a/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc b/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc index 201a0bb2481..9b526f40277 100644 --- a/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc +++ b/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc @@ -321,7 +321,8 @@ void DenseToSparse::runOnFunction() { if (result.needs_densify) { const auto value = op->getOperand(operand); - auto densify = builder.create(op->getLoc(), value); + auto densify = + builder.create(op->getLoc(), value.getType(), value); value.replaceAllUsesWith(densify); densify.setOperand(value); } diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index 13ae216dc25..4c6a16c2233 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -211,6 +211,11 @@ def : Pat<(TF_LogicalOrOp $l, $r), (TFL_LogicalOrOp $l, $r)>; def : Pat<(TF_AddOp $lhs, $rhs), (TFL_AddOp $lhs, $rhs, TFL_AF_None)>; def : Pat<(TF_AddV2Op $lhs, $rhs), (TFL_AddOp $lhs, $rhs, TFL_AF_None)>; +// When batch size is known, TF BatchMatMul gets unfolded to TFL FullyConnected +// with additional ops. In the case of unknown batch size, the match will +// fall through to here and convert to TF Lite BatchMatMul. 
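The fallthrough patterns below map tf.BatchMatMul/BatchMatMulV2 directly onto tfl.batch_matmul when the batch size is unknown. For reference, the sketch that follows spells out the computation those ops describe, with a 2-D right-hand side broadcast across the batch dimension, matching the 2x10x15 * 15x17 -> 2x10x17 shape from the matmul_batchv2 test earlier in this change. It is an illustration of the convention, not the TFLite kernel.

#include <cstdio>
#include <vector>

// Batched matmul with a broadcast RHS: LHS is [batch, m, k], RHS is [k, n],
// and the RHS is reused for every batch element.
std::vector<float> BatchMatMul(const std::vector<float>& lhs,
                               const std::vector<float>& rhs,
                               int batch, int m, int k, int n) {
  std::vector<float> out(static_cast<size_t>(batch) * m * n, 0.0f);
  for (int b = 0; b < batch; ++b) {
    for (int i = 0; i < m; ++i) {
      for (int j = 0; j < n; ++j) {
        float acc = 0.0f;
        for (int p = 0; p < k; ++p) {
          acc += lhs[(b * m + i) * k + p] * rhs[p * n + j];
        }
        out[(b * m + i) * n + j] = acc;
      }
    }
  }
  return out;
}

int main() {
  const int batch = 2, m = 10, k = 15, n = 17;
  std::vector<float> lhs(batch * m * k, 1.0f), rhs(k * n, 1.0f);
  const auto out = BatchMatMul(lhs, rhs, batch, m, k, n);
  std::printf("out[0] = %f (expect %d)\n", out[0], k);
  return 0;
}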
+def : Pat<(TF_BatchMatMulV2Op $lhs, $rhs, $adj_x, $adj_y), (TFL_BatchMatMulOp $lhs, $rhs, $adj_x, $adj_y)>; +def : Pat<(TF_BatchMatMulOp $lhs, $rhs, $adj_x, $adj_y), (TFL_BatchMatMulOp $lhs, $rhs, $adj_x, $adj_y)>; def : Pat<(TF_SubOp $lhs, $rhs), (TFL_SubOp $lhs, $rhs, TFL_AF_None)>; def : Pat<(TF_MulOp $lhs, $rhs), (TFL_MulOp $lhs, $rhs, TFL_AF_None)>; def : Pat<(TF_RealDivOp $lhs, $rhs), (TFL_DivOp $lhs, $rhs, TFL_AF_None)>; @@ -297,7 +302,7 @@ def : Pat<(TF_DepthToSpaceOp $input, $block_size, IsDataFormatNHWC:$data_format) (TFL_DepthToSpaceOp $input, (convertIntAttrTo32Bit $block_size))>; def : Pat<(TF_ResizeBilinearOp $images, $size, $align_corners, $half_pixel_centers), (TFL_ResizeBilinearOp $images, $size, $align_corners, $half_pixel_centers)>; -def : Pat<(TF_ResizeNearestNeighborOp $images, $size, $align_corners, ConstBoolAttrFalse:$half_pixel_centers), (TFL_ResizeNearestNeighborOp $images, $size, $align_corners)>; +def : Pat<(TF_ResizeNearestNeighborOp $images, $size, $align_corners, $half_pixel_centers), (TFL_ResizeNearestNeighborOp $images, $size, $align_corners, $half_pixel_centers)>; def : Pat<(TF_MirrorPadOp $arg0, $arg1, $cst), (TFL_MirrorPadOp $arg0, $arg1, $cst)>; diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index ce0b49fbd49..49be29065fe 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -75,8 +75,6 @@ class TensorListPatternRewriter : public PatternRewriter { public: explicit TensorListPatternRewriter(FuncOp fn) : PatternRewriter(fn.getContext()) {} - - Operation *insert(Operation *op) override { return OpBuilder::insert(op); } }; /// Lower TensorList ops in functions for subsequent legalization. @@ -861,6 +859,7 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction( target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); + target.addLegalOp(); // Register fused LSTM/RNN ops as legal. target.addLegalOp(); target.addLegalOp(); diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 82d9a76fab3..a3244f31053 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -457,3 +457,21 @@ def : Pat<(TFL_AddOp // The constant folding in this pass might produce constant in the tf dialect. // This rule is to legalize these constant to the tfl dialect. def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>; + +// Reorders adds to allow constant folding. 
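The reordering spelled out in the diagram and pattern that follow relies on the associativity of addition: rewriting (input + A) + B as input + (A + B) lets the two constants fold into a single one at compile time, which is what the ReorderAddWithConstant test above checks (dense<1.0> and dense<2.0> becoming a single dense<3.0>). A small numeric sketch, assuming both adds carry no fused activation and the inner result has a single use:

#include <cstdio>

// (x + a) + b equals x + (a + b) up to floating-point rounding, so the two
// constant operands can be folded into one, removing an add at compile time.
int main() {
  const float a = 1.0f, b = 2.0f;  // the two constant operands
  const float folded = a + b;      // 3.0, matching the CHECK line in the test
  const float x = 0.5f;            // stand-in for the runtime input
  std::printf("((x + a) + b) = %f, (x + folded) = %f\n", (x + a) + b, x + folded);
  return 0;
}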
+// Add --> Add $input, $constantA +// \--> $constantB +// To +// Add --> $input +// \--> Add ($constantA, $constantB) +foreach ActFun = [TFL_AF_Relu, TFL_AF_Relu6, TFL_AF_Relu1, TFL_AF_None] in { + def : Pat<(TFL_AddOp + (TFL_AddOp:$first_output $input, (ConstantOp $a), TFL_AF_None), + (ConstantOp $b), ActFun), + (TFL_AddOp $input, + (TFL_AddOp (ConstantOp $a), (ConstantOp $b), TFL_AF_None), + ActFun), + [(HasOneUse $first_output)]>; +} + + diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h index 959c17e317a..105c9394fb4 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/transforms/passes.h @@ -76,7 +76,7 @@ std::unique_ptr> CreateOptimizeFunctionalOpsPass(); // Creates an instance of the TensorFlow Lite dialect pass to add default // quantization parameters. std::unique_ptr> CreateDefaultQuantParamsPass( - double default_min, double default_max); + double default_min, double default_max, bool is_signed); // Creates an instance of the TensorFlow Lite dialect pass to convert dense // tensor to sparse format. diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc index 4f25e434fac..a9e10a485bf 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc @@ -273,8 +273,9 @@ void PrepareQuantizePass::runOnFunction() { // Finally, the quantization parameters can be propagated to the rest of the // values (tensors). - ApplyQuantizationParamsPropagation(func, is_signed, disable_per_channel, - GetOpQuantSpec); + ApplyQuantizationParamsPropagation( + func, is_signed, disable_per_channel || quant_specs_.disable_per_channel, + GetOpQuantSpec); ConvertMlirQuantOpsToTFLQuantOps(func); } diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index a97af8e632e..c5211bdfadb 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -48,6 +48,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h" @@ -612,11 +613,35 @@ struct ConvertTFStridedSlice : public RewritePattern { #include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc" +// Returns success if all the operations in the `op`'s regions including `op` +// itself are legal in a TFLite pipeline. +LogicalResult ValidateOp(Operation *op) { + bool has_illegal_ops = false; + op->walk([&](Operation *op) { + if (isa(op)) { + has_illegal_ops = true; + op->emitOpError() << "is illegal in a TFLite pipeline"; + } + }); + + return failure(has_illegal_ops); +} + void PrepareTFPass::runOnFunction() { OwningRewritePatternList patterns; auto func = getFunction(); MLIRContext *ctx = &getContext(); + // Check illegal ops in a TFLite pipeline (e.g. trainning only ops) , since + // PrepareTFPass is the very first TFLite pass in the pipeline. 
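A small NumPy sketch (illustrative only) of why the TFL_AddOp reordering pattern added to optimize_patterns.td above enables constant folding: after reassociation the two constant operands collapse into a single constant at compile time, leaving one add at runtime.

import numpy as np

x = np.random.rand(2, 3).astype(np.float32)   # runtime input
const_a = np.float32(0.5)                     # first constant operand
const_b = np.float32(1.5)                     # second constant operand

# Original form: Add(Add(x, const_a), const_b) -> two adds at runtime.
original = (x + const_a) + const_b

# Rewritten form: Add(x, Add(const_a, const_b)); the inner add involves
# only constants and folds to 2.0, so a single add remains at runtime.
rewritten = x + (const_a + const_b)

assert np.allclose(original, rewritten)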
+ // TODO(jingpu): It might be better to split this check into its own pass + // to make things more modular. + if (failed(ValidateOp(func))) { + func.emitError() << "tfl-prepare-tf pass failed."; + signalPassFailure(); + return; + } + // This pattern was intented to uses TFL QDQs to preserve the quantization // parameters from the TF Quant ops, thus this pattern should run with the // first `applyPatternsAndFoldGreedily` method, which would otherwise removes diff --git a/tensorflow/compiler/mlir/python/BUILD b/tensorflow/compiler/mlir/python/BUILD index 666f89ac72f..1189a926383 100644 --- a/tensorflow/compiler/mlir/python/BUILD +++ b/tensorflow/compiler/mlir/python/BUILD @@ -12,6 +12,22 @@ cc_library( "//tensorflow/c:tf_status_helper", "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", "//tensorflow/compiler/mlir/tensorflow:error_util", + # (yongtang) The graph_optimization_pass_registration needs to be part + # of a shared object that will be loaded whenever `import tensorflow` + # is run. The natural place is libtensorflow_framework.so. + # While adding graph_optimization_pass_registration to + # libtensorflow_framework.so is possible with some modification in + # dependency, many tests will fail due to multiple copies of LLVM. + # See https://github.com/tensorflow/tensorflow/pull/39231 for details. + # Alternatively, we place graph_optimization_pass_registration here + # because: + # - tensorflow/python/_pywrap_mlir.so already depends on LLVM anyway + # - tensorflow/python/_pywrap_mlir.so always loaded as part of python + # binding + # TODO: It might be still preferrable to place graph_optimization_pass + # as part of the libtensorflow_framework.so, as it is the central + # place for core related components. + "//tensorflow/compiler/mlir/tensorflow:graph_optimization_pass_registration", "//tensorflow/compiler/mlir/tensorflow:import_utils", "@llvm-project//llvm:support", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/python/mlir.cc b/tensorflow/compiler/mlir/python/mlir.cc index d0f6e015922..f22fb519a64 100644 --- a/tensorflow/compiler/mlir/python/mlir.cc +++ b/tensorflow/compiler/mlir/python/mlir.cc @@ -112,7 +112,7 @@ std::string ExperimentalConvertSavedModelV1ToMlir( // Convert the SavedModelBundle to an MLIR module. 
mlir::MLIRContext context; - auto module_or = ConvertSavedModelV1ToMlir(bundle, &context); + auto module_or = ConvertSavedModelV1ToMlir(bundle, {}, &context); if (!module_or.status().ok()) { Set_TF_Status_from_Status(status, module_or.status()); return "// error"; diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD b/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD new file mode 100644 index 00000000000..78f4312da46 --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD @@ -0,0 +1,41 @@ +load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension") + +package(licenses = ["notice"]) + +tf_python_pybind_extension( + name = "mlir_wrapper", + srcs = [ + "attrs.cc", + "basic_classes.cc", + "builders.cc", + "mlir_wrapper.cc", + "mlir_wrapper.h", + "ops.cc", + "types.cc", + ], + module_name = "mlir_wrapper", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/python:pybind11_lib", + "//tensorflow/python:pybind11_status", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", + "@pybind11", + ], +) + +tf_python_pybind_extension( + name = "filecheck_wrapper", + srcs = ["filecheck_wrapper.cc"], + module_name = "filecheck_wrapper", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/python:pybind11_lib", + "//tensorflow/python:pybind11_status", + "@llvm-project//llvm:support", + "@pybind11", + ], +) diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/attrs.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/attrs.cc new file mode 100644 index 00000000000..ca7faf2e1d3 --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/attrs.cc @@ -0,0 +1,25 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h" + +void init_attrs(py::module& m) { + py::class_(m, "Attribute"); + py::class_(m, "IntegerAttr") + .def("get", + py::overload_cast(&mlir::IntegerAttr::get)); +} diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/basic_classes.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/basic_classes.cc new file mode 100644 index 00000000000..25adb44fe1d --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/basic_classes.cc @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llvm/Support/FileCheck.h" +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h" + +void init_basic_classes(py::module& m) { + py::class_(m, "MLIRContext").def(py::init<>()); + + py::class_(m, "Location"); + + py::class_(m, "UnknownLoc") + .def("get", &mlir::UnknownLoc::get); + + py::class_(m, "Region") + .def("back", &mlir::Region::back, py::return_value_policy::reference) + .def("front", &mlir::Region::front, py::return_value_policy::reference) + .def("add_block", [](mlir::Region& r) { r.push_back(new mlir::Block); }) + .def("push_back", &mlir::Region::push_back) + .def("size", [](mlir::Region& r) { return r.getBlocks().size(); }) + .def("front", &mlir::Region::front, py::return_value_policy::reference); + py::class_(m, "Block_Iterator"); + py::class_(m, "Block") + .def("new", ([]() { return new mlir::Block; }), + py::return_value_policy::reference) + .def("end", &mlir::Block::end) + .def("addArgument", &mlir::Block::addArgument); + + py::class_(m, "Value").def("getType", &mlir::Value::getType); + py::class_(m, "OpResult"); + py::class_(m, "BlockArgument"); +} diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/builders.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/builders.cc new file mode 100644 index 00000000000..338f17ed6df --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/builders.cc @@ -0,0 +1,51 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mlir/IR/Builders.h" // from @llvm-project + +#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h" + +void init_builders(py::module& m) { + py::class_(m, "Builder") + .def(py::init()) + .def("getFunctionType", + [](mlir::Builder& b, std::vector inputs, + std::vector outputs) { + return b.getFunctionType(llvm::ArrayRef(inputs), + llvm::ArrayRef(outputs)); + }); + py::class_(m, "OpBuilder") + .def(py::init()) + .def(py::init()) + .def(py::init()) + .def(py::init()) + .def("getUnknownLoc", &mlir::OpBuilder::getUnknownLoc) + .def("setInsertionPoint", + py::overload_cast( + &mlir::OpBuilder::setInsertionPoint)) + .def("saveInsertionPoint", &mlir::OpBuilder::saveInsertionPoint) + .def("restoreInsertionPoint", &mlir::OpBuilder::restoreInsertionPoint) + .def( + "createOperation", + [](mlir::OpBuilder& opb, mlir::OperationState& state) { + return opb.createOperation(state); + }, + py::return_value_policy::reference) + .def("getContext", &mlir::OpBuilder::getContext, + py::return_value_policy::reference); + + py::class_(m, "OpBuilder_InsertionPoint") + .def("getBlock", &mlir::OpBuilder::InsertPoint::getBlock); +} diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc new file mode 100644 index 00000000000..8a841856b72 --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/filecheck_wrapper.cc @@ -0,0 +1,36 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llvm/Support/FileCheck.h" +#include "llvm/Support/SourceMgr.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "tensorflow/python/lib/core/pybind11_lib.h" +#include "tensorflow/python/lib/core/pybind11_status.h" + +PYBIND11_MODULE(filecheck_wrapper, m) { + m.def("check", [](std::string input, std::string check) { + llvm::FileCheckRequest fcr; + llvm::FileCheck fc(fcr); + llvm::SourceMgr SM = llvm::SourceMgr(); + SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input), + llvm::SMLoc()); + SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(check), + llvm::SMLoc()); + llvm::Regex regex = fc.buildCheckPrefixRegex(); + fc.readCheckFile(SM, llvm::StringRef(check), regex); + return fc.checkInput(SM, llvm::StringRef(input)); + }); +} diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc new file mode 100644 index 00000000000..6f468cd4267 --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h" + +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/python/lib/core/pybind11_lib.h" +#include "tensorflow/python/lib/core/pybind11_status.h" + +PYBIND11_MODULE(mlir_wrapper, m) { + m.def("registerDialects", []() { + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); + }); + + init_basic_classes(m); + init_types(m); + init_builders(m); + init_ops(m); + init_attrs(m); +} diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h new file mode 100644 index 00000000000..562c59b43e1 --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H +#define TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H + +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace py = pybind11; + +void init_basic_classes(py::module& m); +void init_types(py::module& m); +void init_builders(py::module& m); +void init_ops(py::module& m); +void init_attrs(py::module& m); + +#endif // TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/ops.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/ops.cc new file mode 100644 index 00000000000..4432829653e --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/ops.cc @@ -0,0 +1,194 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project + +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +void init_ops(py::module& m) { + py::class_>( + m, "Operation") + .def("getRegion", &mlir::Operation::getRegion, + py::return_value_policy::reference) + .def("getResult", &mlir::Operation::getResult) + .def("dump", &mlir::Operation::dump) + .def("getNumResults", &mlir::Operation::getNumResults); + + py::class_(m, "OperationState") + .def(py::init([](mlir::Location loc, std::string name) { + return mlir::OperationState(loc, llvm::StringRef(name)); + })) + .def("addTypes", + [](mlir::OperationState& state, std::vector tys) { + state.addTypes(mlir::ArrayRef(tys)); + }) + .def("addOperands", + [](mlir::OperationState& os, std::vector ops) { + os.addOperands(mlir::ArrayRef(ops)); + }) + .def("addRegion", py::overload_cast<>(&mlir::OperationState::addRegion), + py::return_value_policy::reference); + + py::class_(m, "ModuleOp") + .def("create", + [](mlir::Location loc) { return mlir::ModuleOp::create(loc); }) + .def("push_back", + [](mlir::ModuleOp& m, mlir::FuncOp f) { m.push_back(f); }) + .def("dump", &mlir::ModuleOp::dump) + .def("getAsStr", [](mlir::ModuleOp& m) { + std::string str; + llvm::raw_string_ostream os(str); + m.print(os); + return os.str(); + }); + + py::class_(m, "FuncOp") + .def("create", + [](mlir::Location location, std::string name, + mlir::FunctionType type) { + auto func = mlir::FuncOp::create(location, name, type); + func.addEntryBlock(); + return func; + }) + .def( + "getBody", + [](mlir::FuncOp& f) -> mlir::Region& { return f.getBody(); }, + py::return_value_policy::reference) + .def("getArguments", + [](mlir::FuncOp& f) { return f.getArguments().vec(); }) + .def("getName", [](mlir::FuncOp& f) { return f.getName().str(); }) + .def("getType", &mlir::FuncOp::getType); + + py::class_(m, "ReturnOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, + std::vector values) -> mlir::Operation* { + return opb + .create(loc, + mlir::ArrayRef(values)) + .getOperation(); + }); + + // mlir::TF::AddOp + py::class_(m, "Tf_AddV2Op") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) -> mlir::Operation* { + return opb.create(loc, x, y).getOperation(); + }); + + py::class_(m, "Tf_AnyOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value input, + mlir::Value reduction_indices, + bool keep_dims = false) -> mlir::Operation* { + return opb + .create(loc, opb.getI1Type(), input, + reduction_indices, keep_dims) + .getOperation(); + }); + + // mlir::TF::ConstOp + py::class_(m, "Tf_ConstOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, + mlir::Attribute value) -> mlir::Operation* { + return opb.create(loc, value).getOperation(); + }); + + // mlir::TF::EqualOp + py::class_(m, "Tf_EqualOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) -> mlir::Operation* { + return opb + .create(loc, x, y, opb.getBoolAttr(true)) + .getOperation(); + }); + + // mlir::TF::GreaterEqualOp + py::class_(m, "Tf_GreaterEqualOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) -> mlir::Operation* { + return opb.create(loc, x, y) + 
.getOperation(); + }); + + // mlir::TF::GreaterOp + py::class_(m, "Tf_GreaterOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) -> mlir::Operation* { + return opb.create(loc, x, y).getOperation(); + }); + + // mlir::TF::LegacyCallOp + py::class_(m, "Tf_LegacyCallOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, + std::vector output, std::vector args, + std::string f) -> mlir::Operation* { + return opb + .create( + loc, mlir::ArrayRef(output), + mlir::ArrayRef(args), mlir::StringRef(f)) + .getOperation(); + }); + + // mlir::TF::LessEqualOp + py::class_(m, "Tf_LessEqualOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) -> mlir::Operation* { + return opb.create(loc, x, y).getOperation(); + }); + + // mlir::TF::LessOp + py::class_(m, "Tf_LessOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) -> mlir::Operation* { + return opb.create(loc, x, y).getOperation(); + }); + + // mlir::TF::NegOp + py::class_(m, "Tf_NegOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, + mlir::Value x) -> mlir::Operation* { + return opb.create(loc, x).getOperation(); + }); + + py::class_(m, "Tf_NotEqualOp") + .def("create", [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) { + return opb + .create( + loc, x, y, mlir::BoolAttr::get(true, opb.getContext())) + .getOperation(); + }); + + // mlir::TF::SubOp + py::class_(m, "Tf_SubOp") + .def("create", + [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x, + mlir::Value y) -> mlir::Operation* { + return opb.create(loc, x, y).getOperation(); + }); +} diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc new file mode 100644 index 00000000000..2be67f8e93e --- /dev/null +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc @@ -0,0 +1,48 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +void init_types(py::module& m) { + // Type + py::class_ Type(m, "Type"); + Type.def("getKind", &mlir::Type::getKind); + + // Type Enums + py::enum_(Type, "StandardTypes_Kind") + .value("BF16", mlir::StandardTypes::BF16); + + // Type Sub-classes + py::class_(m, "FunctionType") + .def("getResults", + [](mlir::FunctionType& ft) { return ft.getResults().vec(); }); + + py::class_(m, "FloatType") + .def("get", &mlir::FloatType::get); + + py::class_(m, "IntegerType") + .def("get", py::overload_cast( + &mlir::IntegerType::get)); + + py::class_(m, "UnrankedTensorType") + .def("get", &mlir::UnrankedTensorType::get); + + py::class_(m, "RankedTensorType") + .def("get", [](std::vector shape, mlir::Type ty) { + return mlir::RankedTensorType::get(mlir::ArrayRef(shape), ty); + }); +} diff --git a/tensorflow/compiler/mlir/runlit.cfg.py b/tensorflow/compiler/mlir/runlit.cfg.py index 6d3131a781c..f1271d0da24 100644 --- a/tensorflow/compiler/mlir/runlit.cfg.py +++ b/tensorflow/compiler/mlir/runlit.cfg.py @@ -70,9 +70,9 @@ tool_dirs = config.mlir_tf_tools_dirs + [ ] tool_names = [ 'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate', - 'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate', - 'mlir-tflite-runner', 'tfcompile', 'json_to_flatbuffer', 'xla-gpu-opt', - 'xla-opt' + 'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate', + 'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile', + 'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt' ] tools = [ToolSubst(s, unresolved='ignore') for s in tool_names] llvm_config.add_tool_substitutions(tools, tool_dirs) diff --git a/tensorflow/compiler/mlir/runlit.site.cfg.py b/tensorflow/compiler/mlir/runlit.site.cfg.py index 661e6200df3..3e7596c75d7 100644 --- a/tensorflow/compiler/mlir/runlit.site.cfg.py +++ b/tensorflow/compiler/mlir/runlit.site.cfg.py @@ -44,6 +44,7 @@ mlir_tf_tools_dirs = [ 'tensorflow/compiler/mlir', 'tensorflow/compiler/mlir/lite', 'tensorflow/compiler/mlir/tensorflow', + 'tensorflow/compiler/mlir/tfjs', 'tensorflow/compiler/mlir/xla', 'tensorflow/compiler/aot', 'tensorflow/compiler/xla/service/mlir_gpu', diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index d1fb4343d51..54b560ed6ce 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -36,7 +36,7 @@ filegroup( "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td", "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", - "@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td", + "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", ], ) @@ -342,6 +342,38 @@ cc_library( ], ) +gentbl( + name = "tf_data_optimization_inc_gen", + tbl_outs = [ + ( + "-gen-rewriters", + "transforms/generated_tf_data_optimization.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "transforms/tf_data_optimization.td", + td_srcs = [ + ":tensorflow_ops_td_files", + "@llvm-project//mlir:StdOpsTdFiles", + ], +) + +cc_library( + name = "tf_data_optimization", + srcs = [ + "transforms/tf_data_optimization.cc", + ], + hdrs = [ + "transforms/tf_data_optimization.h", + ], + deps = [ + ":tensorflow", + ":tensorflow_types", + 
":tf_data_optimization_inc_gen", + "@llvm-project//mlir:IR", + ], +) + cc_library( name = "unroll_batch_matmul_pass", srcs = [ @@ -406,10 +438,12 @@ cc_library( "transforms/tensor_array_ops_decomposition.cc", "transforms/tensor_list_ops_decomposition.cc", "transforms/test_side_effect_analysis.cc", + "transforms/tf_data_optimization_pass.cc", "transforms/tf_device_assignment.cc", "transforms/tpu_cluster_formation.cc", "transforms/tpu_dynamic_layout_pass.cc", "transforms/tpu_dynamic_padding_mapper.cc", + "transforms/tpu_extract_head_tail_outside_compilation.cc", "transforms/tpu_extract_outside_compilation.cc", "transforms/tpu_merge_variables_with_execute.cc", "transforms/tpu_rewrite_pass.cc", @@ -443,6 +477,7 @@ cc_library( ":tensorflow", ":tensorflow_optimize_inc_gen", ":tensorflow_types", + ":tf_data_optimization", ":tpu_rewrite_device_util", ":translate_utils", ":unroll_batch_matmul_pass", @@ -521,7 +556,7 @@ cc_library( deps = [ ":tensorflow", "@llvm-project//mlir:IR", - "@llvm-project//mlir:LoopOpsTransforms", + "@llvm-project//mlir:SCFTransforms", ], alwayslink = 1, ) @@ -599,7 +634,6 @@ cc_library( ":error_util", ":parse_text_proto", "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/strings", "@llvm-project//llvm:support", ], @@ -789,6 +823,7 @@ cc_library( ":mangling_util", ":tensorflow_attributes", ":tensorflow_types", + "//tensorflow/compiler/xla:util", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -811,8 +846,10 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core:testlib", "//tensorflow/stream_executor/lib", "@llvm-project//mlir:IR", ], @@ -1038,7 +1075,7 @@ genrule( srcs = [ "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td", "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", - "@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td", + "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", "@llvm-project//mlir:include/mlir/IR/OpBase.td", "ir/tf_generated_ops.td", "ir/tf_op_base.td", @@ -1111,6 +1148,7 @@ COMPILE_MLIR_UTIL_DEPS = [ "//tensorflow/stream_executor/lib", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service:hlo", + ":convert_tensor", ] # Prefer to link 'compile_mlir_util' library that also links necessary diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 89d40566b29..aa1601c4032 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -160,6 +160,8 @@ def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastable TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; let hasCanonicalizer = 1; + + let hasFolder = 1; } def TF_AllOp : TF_Op<"All", [NoSideEffect]> { @@ -190,6 +192,44 @@ retained with length 1. let verifier = [{ return Verify(*this); }]; } +def TF_AllToAllOp : TF_Op<"AllToAll", [NoSideEffect]> { + let summary = "An Op to exchange data across TPU replicas."; + + let description = [{ +On each replica, the input is split into `split_count` blocks along +`split_dimension` and send to the other replicas given group_assignment. 
After +receiving `split_count` - 1 blocks from other replicas, we concatenate the +blocks along `concat_dimension` as the output. + +For example, suppose there are 2 TPU replicas: +replica 0 receives input: `[[A, B]]` +replica 1 receives input: `[[C, D]]` + +group_assignment=`[[0, 1]]` +concat_dimension=0 +split_dimension=1 +split_count=2 + +replica 0's output: `[[A], [C]]` +replica 1's output: `[[B], [D]]` + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, + I32Tensor:$group_assignment, + + I64Attr:$concat_dimension, + I64Attr:$split_dimension, + I64Attr:$split_count + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_AngleOp : TF_Op<"Angle", [NoSideEffect, SameOperandsAndResultShape]> { let summary = "Returns the argument of a complex number."; @@ -1063,6 +1103,26 @@ for dtype in dtype_list: TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_BroadcastArgsOp : TF_Op<"BroadcastArgs", [NoSideEffect]> { + let summary = "Return the shape of s0 op s1 with broadcast."; + + let description = [{ +Given `s0` and `s1`, tensors that represent shapes, compute `r0`, the +broadcasted shape. `s0`, `s1` and `r0` are all integer vectors. + }]; + + let arguments = (ins + TF_I32OrI64Tensor:$s0, + TF_I32OrI64Tensor:$s1 + ); + + let results = (outs + TF_I32OrI64Tensor:$r0 + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_BroadcastGradientArgsOp : TF_Op<"BroadcastGradientArgs", [NoSideEffect]> { let summary = [{ Return the reduction indices for computing gradients of s0 op s1 with broadcast. @@ -1195,7 +1255,7 @@ that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect, SameOperandsAndResultType]> { +def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect]> { let summary = "Clips tensor values to a specified min and max."; let description = [{ @@ -1386,6 +1446,30 @@ tf.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j] let hasCanonicalizer = 1; } +def TF_ConjugateTransposeOp : TF_Op<"ConjugateTranspose", [NoSideEffect]> { + let summary = [{ +Shuffle dimensions of x according to a permutation and conjugate the result. + }]; + + let description = [{ +The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy: + `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]` + `y[i,j,k,...,s,t,u] == conj(x[perm[i], perm[j], perm[k],...,perm[s], perm[t], perm[u]])` + }]; + + let arguments = (ins + TF_Tensor:$x, + TF_I32OrI64Tensor:$perm + ); + + let results = (outs + TF_Tensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tperm = TF_DerivedOperandTypeAttr<1>; +} + def TF_Conv2DOp : TF_Op<"Conv2D", [NoSideEffect, TF_LayoutSensitiveInterface]> { let summary = [{ Computes a 2-D convolution given 4-D `input` and `filter` tensors. 
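For the BroadcastArgs op registered above, a quick sketch of its behaviour through tf.raw_ops (illustrative only): given two shape vectors, it returns the NumPy-style broadcasted shape.

import tensorflow as tf

s0 = tf.constant([2, 1, 3], dtype=tf.int32)
s1 = tf.constant([5, 3], dtype=tf.int32)

# Shapes are aligned from the right; size-1 dimensions broadcast.
r0 = tf.raw_ops.BroadcastArgs(s0=s0, s1=s1)
print(r0.numpy())   # [2 5 3]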
@@ -1660,7 +1744,28 @@ Given an input tensor, this function computes hyperbolic cosine of every TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_CrossReplicaSumOp : TF_Op<"CrossReplicaSum", [AllTypesMatch<["input", "output"]>, NoSideEffect]> { +def TF_CrossOp : TF_Op<"Cross", [NoSideEffect]> { + let summary = "Compute the pairwise cross product."; + + let description = [{ +`a` and `b` must be the same shape; they can either be simple 3-element vectors, +or any shape where the innermost dimension is 3. In the latter case, each pair +of corresponding 3-element vectors is cross-multiplied independently. + }]; + + let arguments = (ins + TF_IntOrFpTensor:$a, + TF_IntOrFpTensor:$b + ); + + let results = (outs + TF_IntOrFpTensor:$product + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_CrossReplicaSumOp : TF_Op<"CrossReplicaSum", [NoSideEffect, TF_AllTypesMatch<["input", "output"]>]> { let summary = "An Op to sum inputs across replicated TPU instances."; let description = [{ @@ -1684,7 +1789,7 @@ and `B, D, F, H` as group 1. Thus we get the outputs: TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_CumsumOp : TF_Op<"Cumsum", [AllTypesMatch<["x", "out"]>, NoSideEffect]> { +def TF_CumsumOp : TF_Op<"Cumsum", [NoSideEffect, TF_AllTypesMatch<["x", "out"]>]> { let summary = "Compute the cumulative sum of the tensor `x` along `axis`."; let description = [{ @@ -1734,6 +1839,169 @@ tf.cumsum([a, b, c], exclusive=True, reverse=True) # => [b + c, c, 0] TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; } +def TF_DataFormatDimMapOp : TF_Op<"DataFormatDimMap", [NoSideEffect, SameOperandsAndResultType]> { + let summary = [{ +Returns the dimension index in the destination data format given the one in + }]; + + let description = [{ +the source data format. + }]; + + let arguments = (ins + TF_I32OrI64Tensor:$x, + + DefaultValuedAttr:$src_format, + DefaultValuedAttr:$dst_format + ); + + let results = (outs + TF_I32OrI64Tensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_DecodeAndCropJpegOp : TF_Op<"DecodeAndCropJpeg", [NoSideEffect]> { + let summary = "Decode and Crop a JPEG-encoded image to a uint8 tensor."; + + let description = [{ +The attr `channels` indicates the desired number of color channels for the +decoded image. + +Accepted values are: + +* 0: Use the number of channels in the JPEG-encoded image. +* 1: output a grayscale image. +* 3: output an RGB image. + +If needed, the JPEG-encoded image is transformed to match the requested number +of color channels. + +The attr `ratio` allows downscaling the image by an integer factor during +decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than +downscaling the image later. + + +It is equivalent to a combination of decode and crop, but much faster by only +decoding partial jpeg image. + }]; + + let arguments = (ins + TF_StrTensor:$contents, + I32Tensor:$crop_window, + + DefaultValuedAttr:$channels, + DefaultValuedAttr:$ratio, + DefaultValuedAttr:$fancy_upscaling, + DefaultValuedAttr:$try_recover_truncated, + DefaultValuedAttr:$acceptable_fraction, + StrAttr:$dct_method + ); + + let results = (outs + TF_Uint8Tensor:$image + ); +} + +def TF_DecodeGifOp : TF_Op<"DecodeGif", [NoSideEffect]> { + let summary = "Decode the frame(s) of a GIF-encoded image to a uint8 tensor."; + + let description = [{ +GIF images with frame or transparency compression are not supported. 
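For the DecodeAndCropJpeg op added above, a usage sketch through tf.io.decode_and_crop_jpeg (illustrative only; "photo.jpg" is a placeholder path): the crop window is decoded directly, which is documented to be equivalent to decoding the whole image and slicing, only cheaper.

import tensorflow as tf

contents = tf.io.read_file("photo.jpg")   # placeholder path
crop_window = [10, 20, 100, 200]          # [crop_y, crop_x, crop_height, crop_width]

# Decode only the requested window instead of the full image.
cropped = tf.io.decode_and_crop_jpeg(contents, crop_window, channels=3)

# Equivalent (but slower): decode everything, then slice the same window.
full = tf.io.decode_jpeg(contents, channels=3)
also_cropped = full[10:110, 20:220, :]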
+On Linux and MacOS systems, convert animated GIFs from compressed to +uncompressed by running: + + convert $src.gif -coalesce $dst.gif + +This op also supports decoding JPEGs and PNGs, though it is cleaner to use +`tf.io.decode_image`. + }]; + + let arguments = (ins + TF_StrTensor:$contents + ); + + let results = (outs + TF_Uint8Tensor:$image + ); +} + +def TF_DecodeJpegOp : TF_Op<"DecodeJpeg", [NoSideEffect]> { + let summary = "Decode a JPEG-encoded image to a uint8 tensor."; + + let description = [{ +The attr `channels` indicates the desired number of color channels for the +decoded image. + +Accepted values are: + +* 0: Use the number of channels in the JPEG-encoded image. +* 1: output a grayscale image. +* 3: output an RGB image. + +If needed, the JPEG-encoded image is transformed to match the requested number +of color channels. + +The attr `ratio` allows downscaling the image by an integer factor during +decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than +downscaling the image later. + + +This op also supports decoding PNGs and non-animated GIFs since the interface is +the same, though it is cleaner to use `tf.io.decode_image`. + }]; + + let arguments = (ins + TF_StrTensor:$contents, + + DefaultValuedAttr:$channels, + DefaultValuedAttr:$ratio, + DefaultValuedAttr:$fancy_upscaling, + DefaultValuedAttr:$try_recover_truncated, + DefaultValuedAttr:$acceptable_fraction, + StrAttr:$dct_method + ); + + let results = (outs + TF_Uint8Tensor:$image + ); +} + +def TF_DecodePngOp : TF_Op<"DecodePng", [NoSideEffect]> { + let summary = "Decode a PNG-encoded image to a uint8 or uint16 tensor."; + + let description = [{ +The attr `channels` indicates the desired number of color channels for the +decoded image. + +Accepted values are: + +* 0: Use the number of channels in the PNG-encoded image. +* 1: output a grayscale image. +* 3: output an RGB image. +* 4: output an RGBA image. + +If needed, the PNG-encoded image is transformed to match the requested number +of color channels. + +This op also supports decoding JPEGs and non-animated GIFs since the interface +is the same, though it is cleaner to use `tf.io.decode_image`. + }]; + + let arguments = (ins + TF_StrTensor:$contents, + + DefaultValuedAttr:$channels + ); + + let results = (outs + TensorOf<[TF_Uint16, TF_Uint8]>:$image + ); + + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; +} + def TF_DepthToSpaceOp : TF_Op<"DepthToSpace", [NoSideEffect]> { let summary = "DepthToSpace for tensors of type T."; @@ -1963,6 +2231,8 @@ def TF_DivOp : TF_Op<"Div", [NoSideEffect, ResultsBroadcastableShape]>, TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; let hasCanonicalizer = 1; + + let hasFolder = 1; } def TF_DivNoNanOp : TF_Op<"DivNoNan", [NoSideEffect, ResultsBroadcastableShape]>, @@ -2195,6 +2465,51 @@ See [Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs) TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_EluGradOp : TF_Op<"EluGrad", [NoSideEffect, SameOperandsAndResultType]> { + let summary = [{ +Computes gradients for the exponential linear (Elu) operation. + }]; + + let description = [{ + }]; + + let arguments = (ins + TF_FpTensor:$gradients, + TF_FpTensor:$outputs + ); + + let results = (outs + TF_FpTensor:$backprops + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_EmptyOp : TF_Op<"Empty", []> { + let summary = [{ +Creates a tensor with the given shape. + +This operation creates a tensor of `shape` and `dtype`. 
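A short sketch of the Empty op summarized above, via tf.raw_ops (illustrative only): it allocates a tensor of the requested shape; with init=True the contents are set to the dtype's default value (zeros for float32), otherwise they are left uninitialized.

import tensorflow as tf

# Allocate a 2x3 float32 tensor and zero-initialize it.
zeros = tf.raw_ops.Empty(shape=[2, 3], dtype=tf.float32, init=True)
print(zeros.numpy())   # [[0. 0. 0.] [0. 0. 0.]]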
+ }]; + + let description = [{ + }]; + + let arguments = (ins + I32Tensor:$shape, + + DefaultValuedAttr:$init + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; + + let hasFolder = 1; +} + def TF_EqualOp : TF_Op<"Equal", [Commutative, NoSideEffect]> { let summary = "Returns the truth value of (x == y) element-wise."; @@ -2592,6 +2907,8 @@ fill([2, 3], 9) ==> [[9, 9, 9] return Verify(*this); }]; + let hasFolder = 1; + let builders = [OpBuilder< "OpBuilder &builder, OperationState &result, Value dims, Value value" >]; @@ -3024,8 +3341,8 @@ Gather slices from `params` axis `axis` according to `indices`. let description = [{ `indices` must be an integer tensor of any dimension (usually 0-D or 1-D). -Produces an output tensor with shape `params.shape[:axis] + indices.shape + -params.shape[axis + 1:]` where: +Produces an output tensor with shape `params.shape[:axis] + +indices.shape[batch_dims:] + params.shape[axis + 1:]` where: ```python # Scalar indices (output is rank(params) - 1). @@ -3597,6 +3914,28 @@ def TF_LeakyReluOp : TF_Op<"LeakyRelu", [NoSideEffect, SameOperandsAndResultType let hasFolder = 1; } +def TF_LeakyReluGradOp : TF_Op<"LeakyReluGrad", [NoSideEffect, SameOperandsAndResultType]> { + let summary = [{ +Computes rectified linear gradients for a LeakyRelu operation. + }]; + + let description = [{ + }]; + + let arguments = (ins + TF_FpTensor:$gradients, + TF_FpTensor:$features, + + DefaultValuedAttr:$alpha + ); + + let results = (outs + TF_FpTensor:$backprops + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_LeftShiftOp : TF_Op<"LeftShift", [NoSideEffect, ResultsBroadcastableShape]>, WithBroadcastableBinOpBuilder { let summary = "Elementwise computes the bitwise left-shift of `x` and `y`."; @@ -3988,7 +4327,7 @@ cublas. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_MatrixBandPartOp : TF_Op<"MatrixBandPart", [AllTypesMatch<["input", "band"]>, NoSideEffect]> { +def TF_MatrixBandPartOp : TF_Op<"MatrixBandPart", [NoSideEffect, TF_AllTypesMatch<["input", "band"]>]> { let summary = [{ Copy a tensor setting everything outside a central band in each innermost matrix to zero. }]; @@ -4895,7 +5234,7 @@ func @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32> @tf.function def foo(x, y): - return = mlir_passthrough_op([x, y], mlir_module, Toutputs=[tf.float32]) + return mlir_passthrough_op([x, y], mlir_module, Toutputs=[tf.float32]) graph_def = foo.get_concrete_function(tf.TensorSpec([10], tf.float32), tf.TensorSpec([10], tf.float32)).graph.as_graph_def() ``` @@ -4960,6 +5299,8 @@ def TF_MulOp : TF_Op<"Mul", [Commutative, NoSideEffect, ResultsBroadcastableShap ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let hasFolder = 1; } def TF_MulNoNanOp : TF_Op<"MulNoNan", [NoSideEffect, ResultsBroadcastableShape]>, @@ -4974,12 +5315,12 @@ Returns x * y element-wise. Returns zero if y is zero, even if x if infinite or }]; let arguments = (ins - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$x, - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$y + TF_FpOrComplexTensor:$x, + TF_FpOrComplexTensor:$y ); let results = (outs - TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$z + TF_FpOrComplexTensor:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -6032,6 +6373,29 @@ is the corresponding input gradient. 
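The corrected GatherV2 shape formula above, params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1:], can be spot-checked with a small example (illustrative only):

import tensorflow as tf

params = tf.zeros([2, 3, 5, 7])
indices = tf.zeros([2, 4], dtype=tf.int32)   # first dim matches the batch dim

out = tf.gather(params, indices, axis=2, batch_dims=1)
# params.shape[:2] + indices.shape[1:] + params.shape[3:]
#   = [2, 3] + [4] + [7]
print(out.shape)   # (2, 3, 4, 7)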
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_RecvTPUEmbeddingActivationsOp : TF_Op<"RecvTPUEmbeddingActivations", []> { + let summary = "An op that receives embedding activations on the TPU."; + + let description = [{ +The TPU system performs the embedding lookups and aggregations specified by +the arguments to TPUEmbeddingEnqueue(Integer/Sparse/SparseTensor)Batch. The +results of these aggregations are visible to the Tensorflow Graph as the +outputs of a RecvTPUEmbeddingActivations op. This op returns a list containing +one Tensor of activations per table specified in the model. There can be at +most one RecvTPUEmbeddingActivations op in the TPU graph. + }]; + + let arguments = (ins + StrAttr:$config + ); + + let results = (outs + Variadic:$outputs + ); + + TF_DerivedResultSizeAttr num_outputs = TF_DerivedResultSizeAttr<0>; +} + def TF_ReluOp : TF_Op<"Relu", [NoSideEffect, SameOperandsAndResultType, TF_LayoutAgnostic]> { let summary = "Computes rectified linear: `max(features, 0)`."; @@ -6070,6 +6434,24 @@ def TF_Relu6Op : TF_Op<"Relu6", [NoSideEffect, SameOperandsAndResultType]> { TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_Relu6GradOp : TF_Op<"Relu6Grad", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes rectified linear 6 gradients for a Relu6 operation."; + + let description = [{ + }]; + + let arguments = (ins + TF_IntOrFpTensor:$gradients, + TF_IntOrFpTensor:$features + ); + + let results = (outs + TF_IntOrFpTensor:$backprops + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_ReluGradOp : TF_Op<"ReluGrad", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes rectified linear gradients for a Relu operation."; @@ -7077,6 +7459,52 @@ def TF_SelectV2Op : TF_Op<"SelectV2", [NoSideEffect]> { ]; } +def TF_SeluOp : TF_Op<"Selu", [NoSideEffect, SameOperandsAndResultType]> { + let summary = [{ +Computes scaled exponential linear: `scale * alpha * (exp(features) - 1)` + }]; + + let description = [{ +if < 0, `scale * features` otherwise. + +To be used together with +`initializer = tf.variance_scaling_initializer(factor=1.0, mode='FAN_IN')`. +For correct dropout, use `tf.contrib.nn.alpha_dropout`. + +See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515) + }]; + + let arguments = (ins + TF_FpTensor:$features + ); + + let results = (outs + TF_FpTensor:$activations + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_SeluGradOp : TF_Op<"SeluGrad", [NoSideEffect, SameOperandsAndResultType]> { + let summary = [{ +Computes gradients for the scaled exponential linear (Selu) operation. + }]; + + let description = [{ + }]; + + let arguments = (ins + TF_FpTensor:$gradients, + TF_FpTensor:$outputs + ); + + let results = (outs + TF_FpTensor:$backprops + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_ShapeOp : TF_Op<"Shape", [NoSideEffect]> { let summary = "Returns the shape of a tensor."; @@ -7715,6 +8143,26 @@ I.e., \\(y = \sqrt{x} = x^{1/2}\\). TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_SqrtGradOp : TF_Op<"SqrtGrad", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes the gradient for the sqrt of `x` wrt its input."; + + let description = [{ +Specifically, `grad = dy * 0.5 / y`, where `y = sqrt(x)`, and `dy` +is the corresponding input gradient. 
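The SqrtGrad formula quoted above, grad = dy * 0.5 / y with y = sqrt(x), can be verified end to end with a gradient tape; a minimal sketch, illustrative only:

import tensorflow as tf

x = tf.constant([1.0, 4.0, 9.0])
dy = tf.constant([1.0, 1.0, 1.0])

with tf.GradientTape() as tape:
  tape.watch(x)
  y = tf.sqrt(x)

# The registered gradient matches dy * 0.5 / y.
grad = tape.gradient(y, x, output_gradients=dy)
tf.debugging.assert_near(grad, dy * 0.5 / y)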
+ }]; + + let arguments = (ins + TF_FpOrComplexTensor:$y, + TF_FpOrComplexTensor:$dy + ); + + let results = (outs + TF_FpOrComplexTensor:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_SquareOp : TF_Op<"Square", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes square of x element-wise."; @@ -8096,6 +8544,8 @@ def TF_SubOp : TF_Op<"Sub", [NoSideEffect, ResultsBroadcastableShape]>, TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; let hasCanonicalizer = 1; + + let hasFolder = 1; } def TF_SumOp : TF_Op<"Sum", [NoSideEffect]> { @@ -9270,6 +9720,30 @@ y + truncate_mod(x, y) = x`. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_TruncatedNormalOp : TF_Op<"TruncatedNormal", []> { + let summary = "Outputs random values from a truncated normal distribution."; + + let description = [{ +The generated values follow a normal distribution with mean 0 and standard +deviation 1, except that values whose magnitude is more than 2 standard +deviations from the mean are dropped and re-picked. + }]; + + let arguments = (ins + TF_I32OrI64Tensor:$shape, + + DefaultValuedAttr:$seed, + DefaultValuedAttr:$seed2 + ); + + let results = (outs + TF_FpTensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; +} + def TF_UniqueOp : TF_Op<"Unique", [NoSideEffect]> { let summary = "Finds unique elements in a 1-D tensor."; @@ -9855,6 +10329,33 @@ https://www.tensorflow.org/xla/operation_semantics#gather TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_XlaHostComputeOp : TF_Op<"XlaHostCompute", []> { + let summary = [{ +A pseudo-op to represent host-side computation in an XLA program. + }]; + + let description = [{ + }]; + + let arguments = (ins + Variadic:$inputs, + + StrArrayAttr:$ancestors, + TF_ShapeAttrArray:$shapes, + SymbolRefAttr:$shape_inference_graph, + StrAttr:$key, + DefaultValuedAttr:$cost_estimate_ns, + DefaultValuedAttr:$tpu_core + ); + + let results = (outs + Variadic:$outputs + ); + + TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>; + TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>; +} + def TF_XlaKeyValueSortOp : TF_Op<"XlaKeyValueSort", [NoSideEffect]> { let summary = "Wraps the XLA Sort operator, documented at"; @@ -9903,6 +10404,24 @@ https://www.tensorflow.org/performance/xla/operation_semantics#pad TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_XlaRecvFromHostOp : TF_Op<"XlaRecvFromHost", []> { + let summary = "An op to receive a tensor from the host."; + + let description = [{ + }]; + + let arguments = (ins + TF_ShapeAttr:$shape, + StrAttr:$key + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedResultTypeAttr Toutput = TF_DerivedResultTypeAttr<0>; +} + def TF_XlaReduceOp : TF_Op<"XlaReduce", [NoSideEffect]> { let summary = "Wraps the XLA Reduce operator, documented at"; @@ -9967,6 +10486,23 @@ i=0...N-1. 
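For the TruncatedNormal op registered above, a usage sketch via the public wrapper tf.random.truncated_normal (illustrative only): samples follow a standard normal, and any draw more than two standard deviations from the mean is discarded and re-drawn.

import tensorflow as tf

samples = tf.random.truncated_normal([10000], mean=0.0, stddev=1.0, seed=42)

# Every sample lies within two standard deviations of the mean.
assert float(tf.reduce_max(tf.abs(samples))) <= 2.0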
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_XlaSendToHostOp : TF_Op<"XlaSendToHost", []> { + let summary = "An op to send a tensor to the host."; + + let description = [{ + }]; + + let arguments = (ins + TF_Tensor:$input, + + StrAttr:$key + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr Tinput = TF_DerivedOperandTypeAttr<0>; +} + def TF_XlaSvdOp : TF_Op<"XlaSvd", [NoSideEffect]> { let summary = [{ Computes the eigen decomposition of a batch of self-adjoint matrices @@ -10050,6 +10586,29 @@ def TF_ZerosLikeOp : TF_Op<"ZerosLike", [NoSideEffect, SameOperandsAndResultType TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF__RecvTPUEmbeddingActivationsOp : TF_Op<"_RecvTPUEmbeddingActivations", []> { + let summary = "An op that receives embeddng activations on the TPU."; + + let description = [{ +The TPU system performs the embedding lookups and aggregations. The results of +these aggregations are visible to the Tensorflow Graph as the outputs of a +_RecvTPUEmbeddingActivations Op. This op returns a list containing one +Tensor of activations per table specified in the model. + }]; + + let arguments = (ins + TF_VariantTensor:$deduplication_data, + + StrAttr:$config + ); + + let results = (outs + Variadic:$outputs + ); + + TF_DerivedResultSizeAttr num_tables = TF_DerivedResultSizeAttr<0>; +} + def TF__TPUCompileMlirOp : TF_Op<"_TPUCompileMlir", []> { let summary = [{ Compiles a computations for execution on one or more TPU devices. @@ -10085,3 +10644,44 @@ used to look up the program in the compilation cache. TF_DerivedResultSizeAttr num_computations = TF_DerivedResultSizeAttr<1>; TF_DerivedOperandSizeAttr NumDynamicShapes = TF_DerivedOperandSizeAttr<0>; } + +def TF__XlaRecvAtHostOp : TF_Op<"_XlaRecvAtHost", []> { + let summary = [{ +A placeholder op to receive values from a running XLA computation. + }]; + + let description = [{ + }]; + + let arguments = (ins + TF_StrTensor:$dynamic_key, + + StrAttr:$key, + I64Attr:$device_ordinal + ); + + let results = (outs + Variadic:$outputs + ); + + TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>; +} + +def TF__XlaSendFromHostOp : TF_Op<"_XlaSendFromHost", []> { + let summary = "A placeholder op to send values to a running XLA computation."; + + let description = [{ + }]; + + let arguments = (ins + Variadic:$inputs, + TF_StrTensor:$dynamic_key, + + StrAttr:$key, + I64Attr:$device_ordinal + ); + + let results = (outs); + + TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>; +} diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td index 80a2b1925e6..dbd8ab0fae2 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td @@ -23,7 +23,7 @@ limitations under the License. 
#define TF_OP_BASE include "mlir/IR/OpBase.td" -include "mlir/Interfaces/SideEffects.td" +include "mlir/Interfaces/SideEffectInterfaces.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td" //===----------------------------------------------------------------------===// @@ -70,6 +70,16 @@ class TF_OpIsBroadcastableToRes : And<[ "$_op.getOperand(" # opId # ").getType(), " "$_op.getResult(" # resId # ").getType())">]>; + +class TF_AllTypesMatchPred values> : + CPred<"TF::AreCastCompatible(llvm::makeArrayRef({"# StrJoin.result #"}))">; + +class TF_AllTypesMatch names> : + PredOpTrait< + "all of {" # StrJoin.result # "} have dynamically equal types ", + TF_AllTypesMatchPred< + !foreach(n, names, !subst("$_self", "$" # n, "$_self.getType()"))>>; + //===----------------------------------------------------------------------===// // TensorFlow op definitions //===----------------------------------------------------------------------===// @@ -129,9 +139,16 @@ def TF_I32Or64 : SignlessIntOfWidths<[32, 64]>; def TF_I32OrI64Tensor : TensorOf<[TF_I32Or64]>; def TF_Uint8 : UI<8>; +def TF_Uint8Tensor : TensorOf<[TF_Uint8]>; + def TF_Uint16 : UI<16>; +def TF_Uint16Tensor : TensorOf<[TF_Uint16]>; + def TF_Uint32 : UI<32>; +def TF_Uint32Tensor : TensorOf<[TF_Uint32]>; + def TF_Uint64 : UI<64>; +def TF_Uint64Tensor : TensorOf<[TF_Uint64]>; // Any unsigned integer type def TF_UInt : UnsignedIntOfWidths<[8, 16, 32, 64]>; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 1b915e3d5fc..2007824369c 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include +#include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/Optional.h" @@ -34,6 +35,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/iterator_range.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/Dialect/Traits.h" // from @llvm-project @@ -108,47 +110,6 @@ static inline bool HasRankAtMost(Value value, int64_t rank) { return !type || type.getRank() <= rank; } -// Returns true if the given pair of TensorFlow types can be cast to one -// another. In other words, a single run-time value is legal for both the types. -// For example, tensor<*xf32> and tensor<3xf32> are cast compatible. -static bool AreCastCompatible(Type a, Type b) { - if (TensorCastOp::areCastCompatible(a, b)) return true; - - // Resource types may optionally contain subtypes information that does not - // match. Check subtypes compatibility when possible, otherwise treat them as - // compatible. 
- auto a_or_element_type = getElementTypeOrSelf(a); - auto b_or_element_type = getElementTypeOrSelf(b); - - auto a_kind = a_or_element_type.getKind(); - auto b_kind = b_or_element_type.getKind(); - - if (a_kind == TensorFlowTypes::RESOURCE && - b_kind == TensorFlowTypes::RESOURCE) { - auto a_resource_type = a_or_element_type.dyn_cast(); - auto b_resource_type = b_or_element_type.dyn_cast(); - bool a_has_subtype = !a_resource_type.getSubtypes().empty(); - bool b_has_subtype = !b_resource_type.getSubtypes().empty(); - - if (!a_has_subtype || !b_has_subtype) return true; - - assert(a_resource_type.getSubtypes().size() <= 1 && - "Resource type must have at most one subtype"); - assert(b_resource_type.getSubtypes().size() <= 1 && - "Resource type must have at most one subtype"); - - return TensorCastOp::areCastCompatible( - a_resource_type.getSubtypes().front(), - b_resource_type.getSubtypes().front()); - } - - // Variant types may optionally contain subtypes information that need not - // match. It is also not possible to compare subtypes for compatibility as - // their interpretation depends on the ops operating on them. So, accept all - // pairs of variant types. - return a_kind == TensorFlowTypes::VARIANT && - b_kind == TensorFlowTypes::VARIANT; -} static bool IsUnknownDimOrRank(int64_t dim_or_rank) { return dim_or_rank == -1; @@ -494,6 +455,57 @@ LogicalResult FoldOperandsPermutation( return success(); } +//===----------------------------------------------------------------------===// +// Rewrite Pattern for removing trivial Arithmetic op. +//===----------------------------------------------------------------------===// + +namespace { +// Folder that returns LHS of an Arithmetic Op if the RHS is a constant +// known to be Identity (e.g X+0) +template ::value>::type * = nullptr> +OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op, + ArrayRef operands) { + auto result_op_type = arithmetic_op.getResult().getType(); + auto lhs_type = arithmetic_op.x().getType().template cast(); + if (!result_op_type.template cast().hasStaticShape()) return {}; + + // We only handle non-broadcastable case. + if (result_op_type != lhs_type) { + return {}; + } + + // Mul and Div ops have identity value one while AddV2 and SubOp have identity + // value zero. 
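As a minimal, hedged sketch of the IR-level effect of this folder (the function name @identity_fold_example and the 2x2 shapes are illustrative; the RemoveTrivial* cases added to constant-fold.mlir later in this change exercise the same behavior):

// Illustrative only: x + 0 and x * 1 fold to x when the operand and result
// types match exactly (no broadcast) and the shape is static.
func @identity_fold_example(%arg0: tensor<2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>) {
  %zero = "tf.Const"() {value = dense<0.0> : tensor<2x2xf32>} : () -> tensor<2x2xf32>
  %one = "tf.Const"() {value = dense<1.0> : tensor<2x2xf32>} : () -> tensor<2x2xf32>
  // Folds to %arg0: AddV2 with an all-zero RHS of identical static type.
  %0 = "tf.AddV2"(%arg0, %zero) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  // Folds to %arg0: Mul with an all-one RHS of identical static type.
  %1 = "tf.Mul"(%arg0, %one) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0, %1 : tensor<2x2xf32>, tensor<2x2xf32>
}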
+ int identity = + (std::is_same::value || std::is_same::value); + + Type element_ty = lhs_type.getElementType(); + Attribute identity_attr; + if (auto ty = element_ty.template dyn_cast()) { + identity_attr = FloatAttr::get(ty, static_cast(identity)); + } else if (auto ty = element_ty.template dyn_cast()) { + identity_attr = IntegerAttr::get(ty, static_cast(identity)); + } else { + return {}; + } + + if (auto attr = operands[1].dyn_cast_or_null()) { + if (attr.isSplat() && attr.getSplatValue() == identity_attr) + return arithmetic_op.x(); + } + + bool is_symmetric = + (std::is_same::value || std::is_same::value); + if (auto attr = operands[0].dyn_cast_or_null()) { + if (is_symmetric && attr.isSplat() && attr.getSplatValue() == identity_attr) + return arithmetic_op.y(); + } + return {}; +} +} // namespace + namespace { #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc" } // namespace @@ -525,6 +537,10 @@ void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results, results.insert(context); } +OpFoldResult AddV2Op::fold(ArrayRef operands) { + return IdentityArithmeticOpFolder(*this, operands); +} + //===----------------------------------------------------------------------===// // AllOp //===----------------------------------------------------------------------===// @@ -927,20 +943,17 @@ void ConstOp::build(OpBuilder &builder, OperationState &result, Type type, LogicalResult ConstOp::inferReturnTypes( MLIRContext *context, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { - for (NamedAttribute named_attr : attributes) { - if (named_attr.first.strref() != "value") continue; - auto value = named_attr.second; - if (auto elem_attr = value.dyn_cast()) { - inferredReturnTypes.assign({elem_attr.getType()}); - return success(); - } - return emitOptionalError(location, - "attribute 'value' failed to satisfy constraint: " - "constant vector/tensor"); + auto value = attributes.get("value"); + if (!value) return emitOptionalError(location, "missing attribute 'value'"); + if (auto elem_attr = value.dyn_cast()) { + inferredReturnTypes.assign({elem_attr.getType()}); + return success(); } - return emitOptionalError(location, "missing attribute 'value'"); + return emitOptionalError(location, + "attribute 'value' failed to satisfy constraint: " + "constant vector/tensor"); } //===----------------------------------------------------------------------===// @@ -1271,6 +1284,10 @@ void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results, results.insert(context); } +OpFoldResult DivOp::fold(ArrayRef operands) { + return IdentityArithmeticOpFolder(*this, operands); +} + //===----------------------------------------------------------------------===// // DynamicStitchOp //===----------------------------------------------------------------------===// @@ -1355,7 +1372,7 @@ static LogicalResult Verify(DynamicStitchOp op) { auto expected_out_ty = RankedTensorType::get(expected_shape, out_ty.getElementType()); - if (!AreCastCompatible(out_ty, expected_out_ty)) { + if (!AreCastCompatible({out_ty, expected_out_ty})) { return op.emitOpError() << "has invalid output type; should be " "compatible with inferred type " << expected_out_ty; @@ -1381,6 +1398,43 @@ static LogicalResult Verify(EinsumOp op) { return success(); } +//===----------------------------------------------------------------------===// +// EmptyOp 
+//===----------------------------------------------------------------------===// + +OpFoldResult EmptyOp::fold(ArrayRef operands) { + assert(operands.size() == 1 && "empty op has one operand"); + + Attribute attr = operands.front(); + if (!attr) return {}; + + auto int_attr = attr.cast(); + SmallVector out_shape; + for (const auto val : int_attr.getValues()) { + out_shape.push_back(val); + } + + auto type = getResult().getType().cast(); + auto etype = type.getElementType(); + + // We can not fold if the result is not static. + if (!type.hasStaticShape()) return {}; + + if (auto float_type = etype.dyn_cast()) { + auto out_type = RankedTensorType::get(out_shape, float_type); + return DenseElementsAttr::get(out_type, + {APFloat(float_type.getFloatSemantics())}); + } + + if (auto int_type = etype.dyn_cast()) { + auto out_type = RankedTensorType::get(out_shape, etype); + APInt val(int_type.getWidth(), 0, int_type.getSignedness()); + return DenseElementsAttr::get(out_type, val); + } + + return {}; +} + //===----------------------------------------------------------------------===// // EmptyTensorListOp //===----------------------------------------------------------------------===// @@ -1552,7 +1606,7 @@ static ShapedType InferFillOpType(Value dims, Value value) { llvm::SmallVector shape; shape.reserve(dims_attr.getNumElements()); - for (const APInt &dim : dims_attr.getValues()) { + for (const APInt dim : dims_attr.getValues()) { shape.push_back(dim.getSExtValue()); } return RankedTensorType::get(shape, etype); @@ -1563,6 +1617,29 @@ void FillOp::build(OpBuilder &builder, OperationState &result, Value dims, FillOp::build(builder, result, InferFillOpType(dims, value), dims, value); } +OpFoldResult FillOp::fold(ArrayRef operands) { + assert(operands.size() == 2 && "fill op has two operand"); + + auto value = operands[1].dyn_cast_or_null(); + if (!value) return {}; + + auto type = getType().cast(); + if (type.hasStaticShape()) + return DenseElementsAttr::get(type, value.getValue({})); + + auto dims = operands[0].dyn_cast_or_null(); + if (!dims) return {}; + + llvm::SmallVector shape; + shape.reserve(dims.getNumElements()); + for (const APInt dim : dims.getValues()) { + shape.push_back(dim.getSExtValue()); + } + type = RankedTensorType::get(shape, type.getElementType()); + + return DenseElementsAttr::get(type, value.getValue({})); +} + //===----------------------------------------------------------------------===// // FusedBatchNormGradOp //===----------------------------------------------------------------------===// @@ -1719,14 +1796,14 @@ static LogicalResult Verify(IfOp op) { for (unsigned i = 0; i < expectedNumInputs; ++i) { auto operandType = op.getOperand(i + 1).getType().cast(); auto thenInputType = thenFuncType.getInput(i).cast(); - if (!AreCastCompatible(operandType, thenInputType)) + if (!AreCastCompatible({operandType, thenInputType})) return op.emitError( llvm::formatv("then branch input type {0} is incompatible with " "operand type {1} at index {2}", thenInputType, operandType, i)); auto elseInputType = elseFuncType.getInput(i).cast(); - if (!AreCastCompatible(operandType, elseInputType)) + if (!AreCastCompatible({operandType, elseInputType})) return op.emitError( llvm::formatv("else branch input type {0} is incompatible with " "operand type {1} at index {2}", @@ -1734,7 +1811,7 @@ static LogicalResult Verify(IfOp op) { // If branches have incompatible input types that means that no tensor can // serve as input to both the functions. Hence, the op is invalid. 
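To make the new list form of AreCastCompatible concrete, here is a hedged, illustrative sketch (the function names @then_unranked, @else_unranked and @if_cast_compatible are invented for this example): a ranked operand can feed branches declared on unranked tensors, because tensor<2xf32> and tensor<*xf32> are dynamically cast compatible, so the verifier accepts the op.

// Illustrative only: operand, branch and result types differ statically but are cast compatible.
func @then_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  return %arg0 : tensor<*xf32>
}
func @else_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  %0 = "tf.Neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}
func @if_cast_compatible(%cond: tensor<i1>, %x: tensor<2xf32>) -> tensor<2xf32> {
  %0 = "tf.If"(%cond, %x) {then_branch = @then_unranked, else_branch = @else_unranked, is_stateless = true} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}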
- if (!AreCastCompatible(thenInputType, elseInputType)) + if (!AreCastCompatible({thenInputType, elseInputType})) return op.emitError(llvm::formatv( "branches inputs have incompatible types {0} and {1} at index {2}", thenInputType, elseInputType, i)); @@ -1750,14 +1827,14 @@ static LogicalResult Verify(IfOp op) { for (unsigned i = 0; i < expectedNumResults; ++i) { auto resultType = op.getResult(i).getType().cast(); auto thenResultType = thenFuncType.getResult(i).cast(); - if (!AreCastCompatible(thenResultType, resultType)) + if (!AreCastCompatible({thenResultType, resultType})) return op.emitError( llvm::formatv("then branch result type {0} is incompatible with op " "result type {1} at index {2}", thenResultType, resultType, i)); auto elseResultType = elseFuncType.getResult(i).cast(); - if (!AreCastCompatible(elseResultType, resultType)) + if (!AreCastCompatible({elseResultType, resultType})) return op.emitError( llvm::formatv("else branch result type {0} is incompatible with op " "result type {1} at index {2}", @@ -1936,6 +2013,14 @@ LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef permutation) { return success(); } +//===----------------------------------------------------------------------===// +// MulOp +//===----------------------------------------------------------------------===// + +OpFoldResult MulOp::fold(ArrayRef operands) { + return IdentityArithmeticOpFolder(*this, operands); +} + //===----------------------------------------------------------------------===// // NegOp //===----------------------------------------------------------------------===// @@ -2904,6 +2989,10 @@ void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results, results.insert(context); } +OpFoldResult SubOp::fold(ArrayRef operands) { + return IdentityArithmeticOpFolder(*this, operands); +} + //===----------------------------------------------------------------------===// // SumOp //===----------------------------------------------------------------------===// @@ -3682,7 +3771,7 @@ static LogicalResult Verify(WhileOp op) { auto aType = a.second[idx]; auto bType = b.second[idx]; - if (!AreCastCompatible(aType, bType)) + if (!AreCastCompatible({aType, bType})) return op.emitError(llvm::formatv( "{0} type {1} is incompatible with {2} type {3} at index {4}", a.first, aType, b.first, bType, idx)); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index 6efa26b3745..94b0c5f5e19 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -626,29 +626,6 @@ def TF_FusedBatchNormExOp : TF_Op<"_FusedBatchNormEx", [NoSideEffect]> { TF_DerivedOperandSizeAttr num_side_inputs = TF_DerivedOperandSizeAttr<5>; } -def TF_RecvTPUEmbeddingActivationsOp : TF_Op<"RecvTPUEmbeddingActivations", []> { - let summary = "An op that receives embedding activations on the TPU."; - - let description = [{ -The TPU system performs the embedding lookups and aggregations specified by -the arguments to TPUEmbeddingEnqueue(Integer/Sparse/SparseTensor)Batch. The -results of these aggregations are visible to the Tensorflow Graph as the -outputs of a RecvTPUEmbeddingActivations op. This op returns a list containing -one Tensor of activations per table specified in the model. There can be at -most one RecvTPUEmbeddingActivations op in the TPU graph. 
- }]; - - let arguments = (ins - StrAttr:$config - ); - - let results = (outs - Variadic:$outputs - ); - - TF_DerivedResultSizeAttr num_outputs = TF_DerivedResultSizeAttr<0>; -} - // Multiple variadic operands with different sizes are not supported by the // dialect generator, so we manually added the op. def TF_SendTPUEmbeddingGradientsOp : TF_Op<"SendTPUEmbeddingGradients", [AttrSizedOperandSegments]> { @@ -680,6 +657,65 @@ config: Serialized TPUEmbeddingConfiguration proto. TF_DerivedOperandSizeAttr NN = TF_DerivedOperandSizeAttr<1>; } +// Multiple variadic operands with different sizes are not supported by the +// dialect generator, so we manually added the op. +def TF__SendTPUEmbeddingGradientsOp : TF_Op<"_SendTPUEmbeddingGradients", [AttrSizedOperandSegments]> { + let summary = "Performs gradient updates of embedding tables."; + + let description = [{ +The gradients argument is a TensorList having the same length and shapes as the +return value of _RecvTPUEmbeddingActivations, but contains gradients of the +model's loss with respect to the embedding activations. The embedding tables are +updated from these gradients via the optimizer specified in the +TPUEmbeddingConfiguration proto given to tpu.initialize_system. + +gradients: A TensorList of gradients with which to update embedding tables. +learning_rates: A TensorList of learning rates used for updating the embedding + tables via the optimizer. The length of the TensorList must be equal to the + number of dynamic learning rate tags specified in the + TPUEmbeddingConfiguration proto. +deduplication_data: A Tensor with type=DT_VARIANT containing the deduplication + data. The tensor is an XLA nested tuple containing N elements. Each + element of the nested tuple is a tuple of rank 1 tensors. Each tensor either + contains indices (DT_INT32) for embedding lookup or weights (DT_FLOAT) to + apply to the output of the embedding lookup operation. +config: Serialized TPUEmbeddingConfiguration proto. + }]; + + let arguments = (ins + Variadic:$gradients, + Variadic:$learning_rates, + TF_VariantTensor:$deduplication_data, + StrAttr:$config + ); + + TF_DerivedOperandSizeAttr NumTables = TF_DerivedOperandSizeAttr<0>; + TF_DerivedOperandSizeAttr NumLearningRateTags = TF_DerivedOperandSizeAttr<1>; +} + +// Updated the op description text from the auto-generated op definition. +def TF__RecvTPUEmbeddingDeduplicationDataOp : TF_Op<"_RecvTPUEmbeddingDeduplicationData", []> { + let summary = [{ +Receives deduplication data (indices and weights). + }]; + + let description = [{ +The deduplication data is a Tensor with type=DT_VARIANT. The tensor itself is an +XLA nested tuple containing N elements. Each element of the nested tuple is a +tuple of rank 1 tensors. Each tensor either contains indices (DT_INT32) for +embedding lookup or weights (DT_FLOAT) to apply to the output of the embedding +lookup operation. + }]; + + let arguments = (ins + StrAttr:$config + ); + + let results = (outs + TF_VariantTensor:$output + ); +} + def TF_XlaShardingOp : TF_Op<"XlaSharding", [NoSideEffect]> { let summary = [{ An op which shards the input based on the given sharding attribute. 
@@ -741,4 +777,157 @@ Formats a string template using a list of tensors, pretty-printing tensor summar TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>; } +//===----------------------------------------------------------------------===// +// tf.data ops +//===----------------------------------------------------------------------===// + +def TF_BatchDatasetV2Op : TF_Op<"BatchDatasetV2", [NoSideEffect]> { + let summary = [{ +Creates a dataset that batches `batch_size` elements from `input_dataset`. + }]; + + let description = [{ + }]; + + let arguments = (ins + TF_VariantTensor:$input_dataset, + I64Tensor:$batch_size, + I1Tensor:$drop_remainder, + + DefaultValuedAttr:$parallel_copy, + Confined]>:$output_types, + Confined]>:$output_shapes + ); + + let results = (outs + TF_VariantTensor:$handle + ); +} + +def TF_MapDatasetOp : TF_Op<"MapDataset", [NoSideEffect]> { + let summary = [{ + Creates a dataset that applies `f` to the outputs of `input_dataset`. + }]; + + let arguments = (ins + TF_VariantTensor:$input_dataset, + Variadic:$other_arguments, + + SymbolRefAttr:$f, + Confined]>:$output_types, + Confined]>:$output_shapes, + DefaultValuedAttr:$use_inter_op_parallelism, + DefaultValuedAttr:$preserve_cardinality + ); + + let results = (outs + TF_VariantTensor:$handle + ); + + TF_DerivedOperandTypeListAttr Targuments = TF_DerivedOperandTypeListAttr<1>; +} + +def TF_MapAndBatchDatasetOp : TF_Op<"MapAndBatchDataset", [NoSideEffect]> { + let summary = "Creates a dataset that fuses mapping with batching."; + + let description = [{ +Creates a dataset that applies `f` to the outputs of `input_dataset` and then +batches `batch_size` of them. + +Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up +to `batch_size * num_parallel_batches` copies of `f` in parallel. + }]; + + let arguments = (ins + TF_VariantTensor:$input_dataset, + Variadic:$other_arguments, + I64Tensor:$batch_size, + I64Tensor:$num_parallel_calls, + I1Tensor:$drop_remainder, + + SymbolRefAttr:$f, + Confined]>:$output_types, + Confined]>:$output_shapes, + DefaultValuedAttr:$preserve_cardinality + ); + + let results = (outs + TF_VariantTensor:$handle + ); + + TF_DerivedOperandTypeListAttr Targuments = TF_DerivedOperandTypeListAttr<1>; +} + +def TF_ParallelMapDatasetOp : TF_Op<"ParallelMapDataset", [NoSideEffect]> { + let summary = [{ + Creates a dataset that applies `f` to the outputs of `input_dataset`. + }]; + + let description = [{ + Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes + up to `num_parallel_calls` copies of `f` in parallel. + }]; + + let arguments = (ins + TF_VariantTensor:$input_dataset, + Variadic:$other_arguments, + I32Tensor:$num_parallel_calls, + + SymbolRefAttr:$f, + Confined]>:$output_types, + Confined]>:$output_shapes, + DefaultValuedAttr:$use_inter_op_parallelism, + DefaultValuedAttr:$sloppy, + DefaultValuedAttr:$preserve_cardinality + ); + + let results = (outs + TF_VariantTensor:$handle + ); + + TF_DerivedOperandTypeListAttr Targuments = TF_DerivedOperandTypeListAttr<1>; +} + +def TF_TensorSliceDatasetOp : TF_Op<"TensorSliceDataset", []> { + let summary = [{ + Creates a dataset that emits each dim-0 slice of `components` once. 
+ }]; + + let arguments = (ins + Variadic:$components, + Confined]>:$output_shapes + ); + + let results = (outs + TF_VariantTensor:$handle + ); + + TF_DerivedOperandTypeListAttr Toutput_types = TF_DerivedOperandTypeListAttr<0>; +} + +// TODO(b/156507832): Move tf.InplaceUpdate to tf_generated_ops.td once +// autogenerated op def matches. +def TF_InplaceUpdateOp : TF_Op<"InplaceUpdate", [NoSideEffect]> { + let summary = "Updates specified rows 'i' with values 'v'."; + + let description = [{ +Computes `x[i, :] = v; return x`. + +Originally this function is mutative however for compilation we make this +operation create / operate on a copy of `x`. + }]; + + let arguments = (ins + TF_Tensor:$x, + I32Tensor:$i, + TF_Tensor:$v + ); + + let results = (outs + TF_Tensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + #endif // TF_OPS diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc index 6c3cd7fac92..d312e5e409b 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc @@ -28,6 +28,134 @@ llvm::Optional> GetShape(mlir::Value value) { if (shaped_type.hasRank()) return shaped_type.getShape(); return llvm::None; } + +// Merges cast compatible shapes and returns a more refined shape. The two +// shapes are cast compatible if they have the same rank and at each dimension, +// either both have same size or one of them is dynamic. Returns false if the +// given shapes are not cast compatible. The refined shape is same or more +// precise than the two input shapes. +bool GetCastCompatibleShape(llvm::ArrayRef a_shape, + llvm::ArrayRef b_shape, + llvm::SmallVectorImpl* refined_shape) { + if (a_shape.size() != b_shape.size()) return false; + int64_t rank = a_shape.size(); + refined_shape->reserve(rank); + for (auto dims : llvm::zip(a_shape, b_shape)) { + int64_t dim1 = std::get<0>(dims); + int64_t dim2 = std::get<1>(dims); + + if (mlir::ShapedType::isDynamic(dim1)) { + refined_shape->push_back(dim2); + continue; + } + if (mlir::ShapedType::isDynamic(dim2)) { + refined_shape->push_back(dim1); + continue; + } + if (dim1 == dim2) { + refined_shape->push_back(dim1); + continue; + } + return false; + } + return true; +} + +// Given two types `a` and `b`, returns a refined type which is cast compatible +// with both `a` and `b` and is equal to or more precise than both of them. It +// returns empty Type if the input types are not cast compatible. +// +// The two types are considered cast compatible if they have dynamically equal +// shapes and element type. For element types that do not have subtypes, they +// must be equal. However for TensorFlow types such as Resource and Variant, +// that also have subtypes, we recursively check for subtype compatibilty for +// Resource types and assume all variant types are cast compatible. If either +// one of `a` or `b` have empty subtypes, they are considered cast compatible. +// +// The returned type is same or more precise than the input types. For example, +// if `a` and `b` are cast compatible types tensor<2x?x?xf32> and +// tensor respectively, the returned type is tensor<2x4x?xf32>. +// +// Provides option to ignore ref types on 'a'. This is useful for TF ops that +// might allow operands to either be same as result type or be a ref type +// corresponding to it. +mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b, + bool may_ignore_ref_type_a) { + // Fast path if everything is equal. 
+ if (a == b) return b; + + auto a_tt = a.dyn_cast(); + auto b_tt = b.dyn_cast(); + + // If only one of a or b is a tensor type, they are incompatible. + if (static_cast(a_tt) ^ static_cast(b_tt)) return nullptr; + + // For non-tensor types, we do not need to worry about shape and can return + // early. + if (!a_tt && !b_tt) { + // Remove ref types. + if (may_ignore_ref_type_a) { + if (auto ref_type = a.dyn_cast()) { + a = ref_type.RemoveRef(); + if (a == b) return a; + } + } + if (a.getKind() != b.getKind()) return nullptr; + + // If either is not a type that contain subtypes then the types are not cast + // compatible. + auto a_wst = a.dyn_cast(); + auto b_wst = b.dyn_cast(); + if (!a_wst || !b_wst) return nullptr; + + // For Variant types we are more permissive right now and accept all pairs + // of Variant types. If we are more constrainted and check compatibility of + // subtypes, we might reject valid graphs. + // TODO(prakalps): Variant doesn't have a subtype, we assign it + // one, so we should only assign it one when we know the subtype. Then we + // can be more constrained and check subtypes for cast compatibility as + // well. + if (a.isa()) return a; + + // For Resource types, we recursively check the subtypes for cast + // compatibility, if possible. Otherwise treat them as compatible. + auto a_wst_st = a_wst.GetSubtypes(); + auto b_wst_st = b_wst.GetSubtypes(); + if (a_wst_st.empty() || b_wst_st.empty()) return a; + if (a_wst_st.size() != b_wst_st.size()) return nullptr; + llvm::SmallVector refined_subtypes; + for (auto subtypes : llvm::zip(a_wst_st, b_wst_st)) { + mlir::Type refined_st = + GetCastCompatibleType(std::get<0>(subtypes), std::get<1>(subtypes), + /*may_ignore_ref_type_a=*/false); + if (!refined_st) return nullptr; + refined_subtypes.push_back(refined_st.cast()); + } + + return mlir::TF::ResourceType::get(refined_subtypes, a.getContext()); + } + + // For tensor types, check compatibility of both element type and shape. + mlir::Type refined_element_ty = GetCastCompatibleType( + a_tt.getElementType(), b_tt.getElementType(), may_ignore_ref_type_a); + if (!refined_element_ty) return nullptr; + + if (!a_tt.hasRank() && !b_tt.hasRank()) { + return mlir::UnrankedTensorType::get(refined_element_ty); + } + if (!a_tt.hasRank()) { + return mlir::RankedTensorType::get(b_tt.getShape(), refined_element_ty); + } + if (!b_tt.hasRank()) { + return mlir::RankedTensorType::get(a_tt.getShape(), refined_element_ty); + } + + llvm::SmallVector refined_shape; + if (!GetCastCompatibleShape(a_tt.getShape(), b_tt.getShape(), &refined_shape)) + return nullptr; + + return mlir::RankedTensorType::get(refined_shape, refined_element_ty); +} } // namespace namespace mlir { @@ -224,44 +352,16 @@ bool BroadcastCompatible(ArrayRef lhs, ArrayRef rhs) { bool HasCompatibleElementTypes(Type lhs, Type rhs, bool may_ignore_ref_type_lhs) { - // Fast path if everything is equal. - if (lhs == rhs) return true; + return GetCastCompatibleType(lhs, rhs, may_ignore_ref_type_lhs) != nullptr; +} - // In TF all values are tensors. - auto lhs_tt = lhs.cast(); - auto rhs_tt = rhs.cast(); - - // Verify matching element types. These should be identical dynamically, - // so this allows for types not yet fully refined. - auto lhs_et = lhs_tt.getElementType(); - auto rhs_et = rhs_tt.getElementType(); - if (lhs_et == rhs_et) return true; - - // Remove ref types. 
- if (may_ignore_ref_type_lhs) { - if (auto ref_type = lhs_et.dyn_cast()) { - lhs_et = ref_type.RemoveRef(); - if (lhs_et == rhs_et) return true; - } - } - - if (lhs_et.getKind() != rhs_et.getKind()) return false; - - // If either is not type that contain subtypes then the element types don't - // match. - auto lhs_wst = lhs_et.dyn_cast(); - auto rhs_wst = rhs_et.dyn_cast(); - if (!lhs_wst || !rhs_wst) return false; - - // Consider the subtype recursively. - auto lhs_wst_st = lhs_wst.GetSubtypes(); - auto rhs_wst_st = rhs_wst.GetSubtypes(); - if (lhs_wst_st.empty() || rhs_wst_st.empty()) return true; - if (lhs_wst_st.size() != rhs_wst_st.size()) return false; - for (auto subtypes : llvm::zip(lhs_wst_st, rhs_wst_st)) { - if (!HasCompatibleElementTypes(std::get<0>(subtypes), - std::get<1>(subtypes))) - return false; +bool AreCastCompatible(ArrayRef types) { + Type common = types.front(); + for (auto type : types.drop_front()) { + Type refined_type = + GetCastCompatibleType(common, type, /*may_ignore_ref_type_a=*/false); + if (!refined_type) return false; + common = refined_type; } return true; } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h index d1e6a74a0c5..4c99aae4706 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h @@ -313,6 +313,12 @@ bool BroadcastCompatible(ArrayRef lhs, ArrayRef rhs); bool HasCompatibleElementTypes(Type lhs, Type rhs, bool may_ignore_ref_type_lhs = false); +// Returns true if all TensorFlow types can be cast to one +// another. In other words, a single run-time value is legal for both the types. +// For example, tensor<*xf32>, tensor and tensor<3xf32> are cast +// compatible. +bool AreCastCompatible(ArrayRef types); + } // end namespace TF } // end namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/tests/annotate-parameter-replication.mlir b/tensorflow/compiler/mlir/tensorflow/tests/annotate-parameter-replication.mlir index 0111d4e4a89..743f0b43b69 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/annotate-parameter-replication.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/annotate-parameter-replication.mlir @@ -10,18 +10,18 @@ module attributes {tf.versions = {producer = 888 : i32}} { %5:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { %2 = "tf._F"(%arg0) : (tensor) -> tensor %3 = "tf.Identity"(%1) : (tensor) -> tensor - %4 = "tf_device.launch_func"(%ri_0, %3, %2) {func = @tpu0_func, device = ""} : (tensor, tensor, tensor) -> tensor + %4 = "tf_device.cluster_func"(%ri_0, %3, %2) {func = @_func, device = ""} : (tensor, tensor, tensor) -> tensor tf_device.return %4 : tensor } %6 = "tf._C"(%5#1) : (tensor) -> tensor return %6 : tensor } - // CHECK-LABEL: func @tpu0_func + // CHECK-LABEL: func @_func // CHECK-SAME: %[[ARG0:.*]]: tensor, // CHECK-SAME: %[[ARG1:.*]]: tensor {tf_device.is_same_data_across_replicas = true} // CHECK-SAME: %[[ARG2:.*]]: tensor) - func @tpu0_func(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + func @_func(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { %0 = "tf._D"(%arg0, %arg1) : (tensor, tensor) -> tensor return %0 : tensor } @@ -46,18 +46,18 @@ module attributes {tf.versions = {producer = 888 : i32}} { [%arg4, %arg5] as %ri_2: tensor>>) {_mirrored_variable_indices = [0, 2], n = 2 : i32} { %0 = "tf.ReadVariableOp"(%ri_0): (tensor>>) -> tensor %1 = "tf.ReadVariableOp"(%ri_1): (tensor>>) -> tensor - %2 = "tf_device.launch_func"(%0, %1, %ri_2) 
{func = @tpu0_func, device = ""} : (tensor, tensor, tensor>>) -> tensor + %2 = "tf_device.cluster_func"(%0, %1, %ri_2) {func = @_func, device = ""} : (tensor, tensor, tensor>>) -> tensor tf_device.return %2 : tensor } %4 = "tf._C"(%3#1) : (tensor) -> tensor return %4 : tensor } - // CHECK-LABEL: func @tpu0_func + // CHECK-LABEL: func @_func // CHECK-SAME: %[[ARG0:.*]]: tensor {tf_device.is_same_data_across_replicas = true}, // CHECK-SAME: %[[ARG1:.*]]: tensor, // CHECK-SAME: %[[ARG2:.*]]: tensor>> {tf_device.is_same_data_across_replicas = true} - func @tpu0_func(%arg0: tensor, %arg1: tensor, %arg2: tensor>>) -> tensor { + func @_func(%arg0: tensor, %arg1: tensor, %arg2: tensor>>) -> tensor { %0 = "tf._D"(%arg0, %arg1) : (tensor, tensor) -> tensor return %0 : tensor } @@ -65,21 +65,21 @@ module attributes {tf.versions = {producer = 888 : i32}} { // ----- -// Tests that a non-replicated LaunchFuncOp is not annotated. +// Tests that a non-replicated ClusterFuncOp is not annotated. module attributes {tf.versions = {producer = 888 : i32}} { // CHECK-LABEL: func @do_not_annotate_without_replicate func @do_not_annotate_without_replicate(%arg0: tensor) -> tensor { %0 = "tf._A"(%arg0) : (tensor) -> tensor %1 = "tf._B"(%arg0) : (tensor) -> tensor - %2 = "tf_device.launch_func"(%0, %1) {func = @tpu0_func, device = ""} : (tensor, tensor) -> tensor + %2 = "tf_device.cluster_func"(%0, %1) {func = @_func, device = ""} : (tensor, tensor) -> tensor %3 = "tf._C"(%2) : (tensor) -> tensor return %3 : tensor } - // CHECK-LABEL: func @tpu0_func + // CHECK-LABEL: func @_func // CHECK-NOT: tf_device.is_same_data_across_replicas - func @tpu0_func(%arg0: tensor, %arg1: tensor) -> tensor { + func @_func(%arg0: tensor, %arg1: tensor) -> tensor { %0 = "tf._D"(%arg0, %arg1) : (tensor, tensor) -> tensor return %0 : tensor } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index 18f8d5f4486..e05894dc266 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -471,3 +471,14 @@ func @testRankOfRankedTensor(%arg0 : tensor<4x3x2xf32>) -> tensor { // CHECK: return [[VAL0]] return %0 : tensor } + +// CHECK-LABEL: @foldFill +func @foldFill() -> (tensor<3x2x1xf32>, tensor<*xf32>) { + %0 = "tf.Const"() {value = dense<[3, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32> + %1 = "tf.Const"() {value = dense<23.0> : tensor} : () -> tensor + // CHECK: "tf.Const"() {value = dense<2.300000e+01> : tensor<3x2x1xf32>} + %2 = "tf.Fill"(%0, %1) : (tensor<3xi32>, tensor) -> tensor<3x2x1xf32> + // CHECK: "tf.Const"() {value = dense<2.300000e+01> : tensor<3x2x1xf32>} + %3 = "tf.Fill"(%0, %1) : (tensor<3xi32>, tensor) -> tensor<*xf32> + return %2, %3 : tensor<3x2x1xf32>, tensor<*xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir b/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir index 1866879c465..42ed55deeda 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir @@ -1,127 +1,120 @@ -// RUN: tf-opt %s -split-input-file -tf-device-cluster-outlining | FileCheck %s +// RUN: tf-opt %s -split-input-file -tf-device-cluster-outlining | FileCheck %s -dump-input-on-failure -// Tests simple case of a single `tf_device.launch`. +// Tests simple case of a single `tf_device.cluster`. 
-module { - // CHECK-LABEL: func @multiplelaunches - // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) - func @multiplelaunches(%arg0: tensor) -> tensor { - %0 = tf_executor.graph { - %1:2 = tf_executor.island { - // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"(%[[ARG_0]]) - %2 = "tf.A"(%arg0) : (tensor) -> tensor +// CHECK-LABEL: func @single_cluster +// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) +func @single_cluster(%arg0: tensor) -> tensor { + %0 = tf_executor.graph { + %1:2 = tf_executor.island { + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"(%[[ARG_0]]) + %2 = "tf.A"(%arg0) : (tensor) -> tensor - // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf_device.launch_func"(%[[A_OUTPUT]]) {device = "tpu0", func = @tpu0_func} - %3 = "tf_device.launch"() ( { - %4 = "tf.B"(%2) : (tensor) -> tensor - tf_device.return %4 : tensor - }) {device = "tpu0"} : () -> tensor + // CHECK: %[[CLUSTER_OUTPUT:[0-9]*]] = "tf_device.cluster_func"(%[[A_OUTPUT]]) {func = @[[CLUSTER:.*]]} + %3 = "tf_device.cluster"() ( { + %4 = "tf.B"(%2) : (tensor) -> tensor + tf_device.return %4 : tensor + }) {} : () -> tensor - // CHECK: tf_executor.yield %[[C_OUTPUT]] - tf_executor.yield %3 : tensor - } - tf_executor.fetch %1#0 : tensor + // CHECK: tf_executor.yield %[[CLUSTER_OUTPUT]] + tf_executor.yield %3 : tensor } - return %0 : tensor + tf_executor.fetch %1#0 : tensor } - -// CHECK-LABEL: func @tpu0_func -// CHECK-SAME: (%[[TPU0_FUNC_ARG_0:[a-z0-9]*]]: tensor) -> tensor -// CHECK-SAME: sym_visibility = "private" -// CHECK: %[[TPU0_FUNC_B_OUTPUT:[0-9]*]] = "tf.B"(%[[TPU0_FUNC_ARG_0]]) -// CHECK: return %[[TPU0_FUNC_B_OUTPUT]] + return %0 : tensor } +// CHECK: func @[[CLUSTER]] +// CHECK-SAME: (%[[CLUSTER_ARG_0:[a-z0-9]*]]: tensor) -> tensor +// CHECK-SAME: sym_visibility = "private" +// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[CLUSTER_ARG_0]]) +// CHECK: return %[[B_OUTPUT]] + // ----- -// Tests that multiple `tf_device.launch` that depend on each other are +// Tests that multiple `tf_device.cluster` that depend on each other are // correctly handled. 
-module { - // CHECK-LABEL: func @multiplelaunches - // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) - func @multiplelaunches(%arg0: tensor) -> tensor { - %0 = tf_executor.graph { - %1:2 = tf_executor.island { - // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"(%[[ARG_0]]) - %2 = "tf.A"(%arg0) : (tensor) -> tensor +// CHECK-LABEL: func @multiple_clusters +// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) +func @multiple_clusters(%arg0: tensor) -> tensor { + %0 = tf_executor.graph { + %1:2 = tf_executor.island { + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"(%[[ARG_0]]) + %2 = "tf.A"(%arg0) : (tensor) -> tensor - // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf_device.launch_func"(%[[A_OUTPUT]]) {device = "tpu0", func = @tpu0_func} - %3 = "tf_device.launch"() ( { - %6 = "tf.B"(%2) : (tensor) -> tensor - tf_device.return %6 : tensor - }) {device = "tpu0"} : () -> tensor + // CHECK: %[[CLUSTER_0_OUTPUT:[0-9]*]] = "tf_device.cluster_func"(%[[A_OUTPUT]]) {func = @[[CLUSTER_0:.*]]} + %3 = "tf_device.cluster"() ( { + %6 = "tf.B"(%2) : (tensor) -> tensor + tf_device.return %6 : tensor + }) {} : () -> tensor - // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[C_OUTPUT]]) - %4 = "tf.D"(%3) : (tensor) -> tensor + // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[CLUSTER_0_OUTPUT]]) + %4 = "tf.D"(%3) : (tensor) -> tensor - // CHECK: %[[E_OUTPUT:[0-9]*]] = "tf_device.launch_func"(%[[C_OUTPUT]], %[[D_OUTPUT]]) {device = "gpu0", func = @gpu0_func} - %5 = "tf_device.launch"() ( { - %6 = "tf.E"(%3) : (tensor) -> tensor - %7 = "tf.F"(%4, %6) : (tensor, tensor) -> tensor - tf_device.return %7 : tensor - }) {device = "gpu0"} : () -> tensor + // CHECK: %[[CLUSTER_1_OUTPUT:[0-9]*]] = "tf_device.cluster_func"(%[[CLUSTER_0_OUTPUT]], %[[D_OUTPUT]]) {func = @[[CLUSTER_1:.*]]} + %5 = "tf_device.cluster"() ( { + %6 = "tf.E"(%3) : (tensor) -> tensor + %7 = "tf.F"(%4, %6) : (tensor, tensor) -> tensor + tf_device.return %7 : tensor + }) {} : () -> tensor - // CHECK: tf_executor.yield %[[E_OUTPUT]] - tf_executor.yield %5 : tensor - } - tf_executor.fetch %1#0 : tensor + // CHECK: tf_executor.yield %[[CLUSTER_1_OUTPUT]] + tf_executor.yield %5 : tensor } - return %0 : tensor + tf_executor.fetch %1#0 : tensor } - -// CHECK-LABEL: func @tpu0_func -// CHECK-SAME: (%[[TPU0_FUNC_ARG_0:[a-z0-9]*]]: tensor) -> tensor -// CHECK: %[[TPU0_FUNC_B_OUTPUT:[0-9]*]] = "tf.B"(%[[TPU0_FUNC_ARG_0]]) -// CHECK: return %[[TPU0_FUNC_B_OUTPUT]] - -// CHECK-LABEL: func @gpu0_func -// CHECK-SAME: (%[[GPU0_FUNC_ARG_0:[a-z0-9]*]]: tensor, %[[GPU0_FUNC_ARG_1:[a-z0-9]*]]: tensor) -> tensor -// CHECK: %[[GPU0_FUNC_E_OUTPUT:[0-9]*]] = "tf.E"(%[[GPU0_FUNC_ARG_0]]) -// CHECK: %[[GPU0_FUNC_F_OUTPUT:[0-9]*]] = "tf.F"(%[[GPU0_FUNC_ARG_1]], %[[GPU0_FUNC_E_OUTPUT]]) -// CHECK: return %[[GPU0_FUNC_F_OUTPUT]] + return %0 : tensor } +// CHECK: func @[[CLUSTER_0]] +// CHECK-SAME: (%[[CLUSTER_0_ARG_0:[a-z0-9]*]]: tensor) -> tensor +// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[CLUSTER_0_ARG_0]]) +// CHECK: return %[[B_OUTPUT]] + +// CHECK: func @[[CLUSTER_1]] +// CHECK-SAME: (%[[CLUSTER_1_ARG_0:[a-z0-9]*]]: tensor, %[[CLUSTER_1_ARG_1:[a-z0-9]*]]: tensor) -> tensor +// CHECK: %[[E_OUTPUT:[0-9]*]] = "tf.E"(%[[CLUSTER_1_ARG_0]]) +// CHECK: %[[F_OUTPUT:[0-9]*]] = "tf.F"(%[[CLUSTER_1_ARG_1]], %[[E_OUTPUT]]) +// CHECK: return %[[F_OUTPUT]] + // ----- -// Tests outlining launches with no live-in values. +// Tests outlining clusters with no live-in values. 
-module { - // CHECK-LABEL: func @multiplelaunches - // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) - func @multiplelaunches(%arg0: tensor) -> tensor { - %0 = tf_executor.graph { - %1:2 = tf_executor.island wraps - // CHECK: %[[A_OUTPUT:[a-z0-9]*]], %{{.*}} = {{.*}} "tf_device.launch_func"() {device = "tpu0", func = @tpu0_func} - "tf_device.launch"() ( { - %3 = "tf.A"() : () -> tensor - tf_device.return %3 : tensor - }) {device = "tpu0"} : () -> tensor - // CHECK: tf_executor.fetch %[[A_OUTPUT]] - tf_executor.fetch %1#0 : tensor - } - return %0 : tensor +// CHECK-LABEL: func @cluster_operands +// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) +func @cluster_operands(%arg0: tensor) -> tensor { + %0 = tf_executor.graph { + %1:2 = tf_executor.island wraps + // CHECK: %[[CLUSTER_OUTPUT:[a-z0-9]*]], %{{.*}} = {{.*}} "tf_device.cluster_func"() {func = @[[CLUSTER:.*]]} + "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> tensor + tf_device.return %3 : tensor + }) {} : () -> tensor + // CHECK: tf_executor.fetch %[[CLUSTER_OUTPUT]] + tf_executor.fetch %1#0 : tensor } + return %0 : tensor +} -// CHECK-LABEL: func @tpu0_func +// CHECK: func @[[CLUSTER]] // CHECK-SAME: () -> tensor -// CHECK: %[[TPU0_FUNC_A_OUTPUT:[0-9]*]] = "tf.A"() -// CHECK: return %[[TPU0_FUNC_A_OUTPUT]] -} +// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"() +// CHECK: return %[[A_OUTPUT]] // ----- -// Tests launch attributes are copied over to launch_func. +// Tests cluster attributes are copied over to cluster_func. -module { - // CHECK-LABEL: func @launch_attrs - func @launch_attrs() -> tensor { - %0 = "tf_device.launch"() ( { - %1 = "tf.A"() : () -> tensor - tf_device.return %1 : tensor - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor - return %0 : tensor - } - -// CHECK: launch_attr = "launch_attr" +// CHECK-LABEL: func @cluster_attrs +func @cluster_attrs() -> tensor { + %0 = "tf_device.cluster"() ( { + %1 = "tf.A"() : () -> tensor + tf_device.return %1 : tensor + }) {cluster_attr = "cluster_attr"} : () -> tensor + return %0 : tensor } + +// CHECK: "tf_device.cluster_func" +// CHECK-SAME: cluster_attr = "cluster_attr" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir index 2a34bbfacdc..bccb8923134 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir @@ -38,6 +38,56 @@ func @testPow(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, ten return %0, %1, %2 : tensor<4xf32>, tensor<4xf32>, tensor<4xf32> } +// CHECK-LABEL: func @testEmpty32 +func @testEmpty32() -> (tensor<5xi32>) { + %0 = "tf.Const"() { value = dense<5> : tensor } : () -> tensor + + // CHECK: [[VAL:%.+]] = "tf.Const"() {value = dense<0> : tensor<5xi32>} + // CHECK: return [[VAL]] + %1 = "tf.Empty"(%0) : (tensor) -> (tensor<5xi32>) + return %1 : tensor<5xi32> +} + +// CHECK-LABEL: func @testEmpty64 +func @testEmpty64() -> (tensor<5xi64>) { + %0 = "tf.Const"() { value = dense<5> : tensor } : () -> tensor + + // CHECK: [[VAL:%.+]] = "tf.Const"() {value = dense<0> : tensor<5xi64>} + // CHECK: return [[VAL]] : tensor<5xi64> + %1 = "tf.Empty"(%0) : (tensor) -> (tensor<5xi64>) + return %1 : tensor<5xi64> +} + +// CHECK-LABEL: func @testEmptyFloat +func @testEmptyFloat() -> (tensor<5xf64>) { + %0 = "tf.Const"() { value = dense<5> : tensor } : () -> tensor + + // CHECK: [[VAL:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<5xf64>} + // CHECK: return [[VAL]] + %1 = 
"tf.Empty"(%0) : (tensor) -> (tensor<5xf64>) + return %1 : tensor<5xf64> +} + +// CHECK-LABEL: func @testEmptyf16 +func @testEmptyf16() -> (tensor<5xf16>) { + %0 = "tf.Const"() { value = dense<5> : tensor } : () -> tensor + + // CHECK: [[VAL:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<5xf16>} + // CHECK: return [[VAL]] + %1 = "tf.Empty"(%0) : (tensor) -> (tensor<5xf16>) + return %1 : tensor<5xf16> +} + +// CHECK-LABEL: func @testEmptybf16 +func @testEmptybf16() -> (tensor<5xbf16>) { + %0 = "tf.Const"() { value = dense<5> : tensor } : () -> tensor + + // CHECK: [[VAL:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<5xbf16>} + // CHECK: return [[VAL]] + %1 = "tf.Empty"(%0) : (tensor) -> (tensor<5xbf16>) + return %1 : tensor<5xbf16> +} + // CHECK-LABEL: func @testShapeN func @testShapeN(%arg0: tensor, %arg1: tensor<1x32x32x16xf32>, %arg2: tensor<*xf32>) -> (tensor<0xi64>, tensor<4xi64>, tensor<4xi64>, tensor) { @@ -251,3 +301,138 @@ func @testTensorListElementShape(%arg0: tensor>>) -> // CHECK-NEXT: return [[cst]] : tensor<2xi32> return %0: tensor<2xi32> } + +func @RemoveTrivialAdd(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { + %cst = constant dense<0.0> : tensor<2x2xf32> + %0 = "tf.Add"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + %1 = "tf.Add"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %1 : tensor<2x2xf32> + + // CHECK-LABEL: RemoveTrivialAdd + // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32> +} + +func @RemoveTrivialAddBf16RHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> { + %cst = constant dense<0.0> : tensor<2x2xbf16> + %0 = "tf.Add"(%arg0, %cst) : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xbf16> + return %0 : tensor<2x2xbf16> + + // CHECK-LABEL: RemoveTrivialAdd + // CHECK-NEXT: return %arg0 : tensor<2x2xbf16> +} + +func @RemoveTrivialAddBf16LHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> { + %cst = constant dense<0.0> : tensor<2x2xbf16> + %0 = "tf.Add"(%cst, %arg0) : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xbf16> + return %0 : tensor<2x2xbf16> + + // CHECK-LABEL: RemoveTrivialAdd + // CHECK-NEXT: return %arg0 : tensor<2x2xbf16> +} + +func @RemoveTrivialAddV2(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { + %cst = constant dense<0.0> : tensor<2x2xf32> + %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + %1 = "tf.AddV2"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %1 : tensor<2x2xf32> + + // CHECK-LABEL: RemoveTrivialAddV2 + // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32> +} + +func @RemoveTrivialSub(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { + %cst = constant dense<0.0> : tensor<2x2xf32> + %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + %1 = "tf.Sub"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %1 : tensor<2x2xf32> + + // CHECK-LABEL: RemoveTrivialSub + // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32> +} + +func @RemoveTrivialSubInt8(%arg0: tensor<2x2xi8>) -> tensor<2x2xi8> { + %cst = constant dense<0> : tensor<2x2xi8> + %0 = "tf.Sub"(%arg0, %cst) 
: (tensor<2x2xi8>, tensor<2x2xi8>) -> tensor<2x2xi8> + return %0 : tensor<2x2xi8> + + // CHECK-LABEL: RemoveTrivialSubInt8 + // CHECK-NEXT: return %arg0 : tensor<2x2xi8> +} + +func @RemoveTrivialMul(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { + %cst = constant dense<1.0> : tensor<2x2xf32> + %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + %1 = "tf.Mul"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %1 : tensor<2x2xf32> + + // CHECK-LABEL: RemoveTrivialMul + // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32> +} + +func @RemoveTrivialDiv(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { + %cst = constant dense<1.0> : tensor<2x2xf32> + %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + %1 = "tf.Div"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %1 : tensor<2x2xf32> + + // CHECK-LABEL: RemoveTrivialDiv + // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32> +} + +func @RemoveTrivialDivBf16RHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> { + %cst = constant dense<1.0> : tensor<2x2xbf16> + %0 = "tf.Div"(%arg0, %cst) : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xbf16> + return %0 : tensor<2x2xbf16> + + // CHECK-LABEL: RemoveTrivialDiv + // CHECK-NEXT: return %arg0 : tensor<2x2xbf16> +} + +func @RemoveTrivialMulInt8(%arg0: tensor<2x2xi8>) -> tensor<2x2xi8> { + %cst = constant dense<1> : tensor<2x2xi8> + %0 = "tf.Mul"(%cst, %arg0) : (tensor<2x2xi8>, tensor<2x2xi8>) -> tensor<2x2xi8> + return %0 : tensor<2x2xi8> + + // CHECK-LABEL: RemoveTrivialMulInt8 + // CHECK-NEXT: return %arg0 : tensor<2x2xi8> +} + +func @DivBf16LHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> { + %cst = constant dense<1.0> : tensor<2x2xbf16> + %0 = "tf.Div"(%cst, %arg0) : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xbf16> + return %0 : tensor<2x2xbf16> + + // CHECK-LABEL: DivBf16LHS + // CHECK: tf.Div +} + +func @DontRemoveTrivialAdd(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<2x2xf32> { + %cst = constant dense<0.0> : tensor<2x2xf32> + %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32> + %1 = "tf.AddV2"(%0, %cst) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %1 : tensor<2x2xf32> + + // CHECK-LABEL: DontRemoveTrivialAdd + // CHECK: %[[CONST:.*]] = constant dense<0.000000e+00> : tensor<2x2xf32> + // CHECK: %[[add:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32> + // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%[[add]], %[[CONST]]) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK: return %[[RESULT]] : tensor<2x2xf32> +} + +func @DontRemoveTrivialAdd2(%arg0: tensor, %arg1: tensor<2x2xf32>) -> tensor { + %cst = constant dense<0.0> : tensor<2x2xf32> + %0 = "tf.AddV2"(%arg0, %arg1) : (tensor, tensor<2x2xf32>) -> tensor + %1 = "tf.AddV2"(%0, %cst) : (tensor , tensor<2x2xf32>) -> tensor + return %1 :tensor + + // CHECK-LABEL: DontRemoveTrivialAdd2 + // CHECK: %[[CONST:.*]] = constant dense<0.000000e+00> : tensor<2x2xf32> + // CHECK: %[[add:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor, tensor<2x2xf32>) -> tensor + // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%[[add]], %[[CONST]]) : (tensor, tensor<2x2xf32>) -> tensor + // CHECK: 
return %[[RESULT]] : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/tf-data-pipeline.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/tf-data-pipeline.pbtxt new file mode 100644 index 00000000000..1e640baa507 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/tf-data-pipeline.pbtxt @@ -0,0 +1,256 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-output-arrays=BatchDatasetV2 -o - | FileCheck %s --dump-input-on-failure + +# CHECK-LABEL: func @main() -> tensor<*x!tf.variant> +# CHECK: %[[tensor_slice:.*]], %[[tensor_slice_control:.*]] = tf_executor.island wraps "tf.TensorSliceDataset" +# CHECK: %[[map_dataset:.*]], %[[map_dataset_control:.*]] = tf_executor.island wraps "tf.MapDataset"(%[[tensor_slice]] +# CHECK: %[[batch_dataset:.*]], %[[batch_dataset_control:.*]] = tf_executor.island wraps "tf.BatchDatasetV2"(%[[map_dataset]] + +node { + name: "tensors/normalize_tensors/component_0" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 3 + } + } + tensor_content: "\000\000\000\000\001\000\000\000\002\000\000\000" + } + } + } +} +node { + name: "TensorSliceDataset" + op: "TensorSliceDataset" + input: "tensors/normalize_tensors/component_0" + attr { + key: "Toutput_types" + value { + list { + type: DT_INT32 + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "MapDataset" + op: "MapDataset" + input: "TensorSliceDataset" + attr { + key: "Targuments" + value { + list { + } + } + } + attr { + key: "f" + value { + func { + name: "__inference_Dataset_map__8" + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_INT32 + } + } + } + attr { + key: "preserve_cardinality" + value { + b: false + } + } + attr { + key: "use_inter_op_parallelism" + value { + b: true + } + } +} +node { + name: "batch_size" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 5 + } + } + } +} +node { + name: "drop_remainder" + op: "Const" + attr { + key: "dtype" + value { + type: DT_BOOL + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_BOOL + tensor_shape { + } + bool_val: false + } + } + } +} +node { + name: "BatchDatasetV2" + op: "BatchDatasetV2" + input: "MapDataset" + input: "batch_size" + input: "drop_remainder" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_INT32 + } + } + } + attr { + key: "parallel_copy" + value { + b: false + } + } +} +library { + function { + signature { + name: "__inference_Dataset_map__8" + input_arg { + name: "args_0" + type: DT_INT32 + } + output_arg { + name: "identity" + type: DT_INT32 + } + } + node_def { + name: "mul/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node_def { + name: "mul" + op: "Mul" + input: "args_0" + input: "mul/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "Identity" + op: "Identity" + input: "mul:z:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + ret { + key: "identity" + value: 
"Identity:output:0" + } + arg_attr { + key: 0 + value { + attr { + key: "_user_specified_name" + value { + s: "args_0" + } + } + } + } + } +} +versions { + producer: 134 + min_consumer: 12 +} + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/stringescape.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/stringescape.mlir index 1ab0195f33a..4b6600d3b16 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/stringescape.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/stringescape.mlir @@ -11,7 +11,7 @@ func @main() { // CHECK-NEXT: value { // CHECK-NEXT: s: " 0\n\000\000" tf_executor.graph { - %0:2 = tf_executor.island wraps "tf.Empty"() {name = "dummy", dtype = "tfdtype$DT_INT32", value = "\200\n\00\00", listvalue = ["\20\0A"]} : () -> tensor<2xi32> + %0:2 = tf_executor.island wraps "tf.Placeholder"() {name = "dummy", dtype = "tfdtype$DT_INT32", value = "\200\n\00\00", listvalue = ["\20\0A"]} : () -> tensor<2xi32> tf_executor.fetch } return diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/type_list_attr.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/type_list_attr.mlir index 4a09af84438..466c5adb0e5 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/type_list_attr.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/type_list_attr.mlir @@ -14,7 +14,7 @@ func @main() { // CHECK-NEXT: type: DT_FLOAT // CHECK-NEXT: } // CHECK-NEXT: } - %0:2 = tf_executor.island wraps "tf.Empty"() {name = "dummy", dtype = "tfdtype$DT_FLOAT", emptylist = [], typelist = ["tfdtype$DT_INT32", "tfdtype$DT_FLOAT"]} : () -> tensor<*xi32> + %0:2 = tf_executor.island wraps "tf.Placeholder"() {name = "dummy", dtype = "tfdtype$DT_FLOAT", emptylist = [], typelist = ["tfdtype$DT_INT32", "tfdtype$DT_FLOAT"]} : () -> tensor<*xi32> tf_executor.fetch } return diff --git a/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir b/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir index e7f4873594b..60663f4bd4a 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir @@ -1,11 +1,11 @@ // RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-promote-resources-to-args | FileCheck %s -dump-input-on-failure // One resource, one read. The initial value of the resource is read. -// CHECK-LABEL: func @main(%arg0: tensor {tf.resource_name = "x"}) -> tensor<2xf32> -func @main() -> tensor<2xf32> { +// CHECK-LABEL: func @main(%arg0: tensor, %arg1: tensor {tf.resource_name = "x"}) -> tensor<2xf32> +func @main(%arg0: tensor) -> tensor<2xf32> { // CHECK-NOT: "tf.VarHandleOp" // CHECK-NOT: "tf.ReadVariableOp" - // CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%arg0, %[[CONST:[0-9]*]]) + // CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%arg1, %[[CONST:[0-9]*]]) // CHECK: %[[PACK:[0-9]*]] = "tf.Pack"(%[[CONST]], %[[ADD]]) // CHECK: return %[[PACK]] %0 = "tf.Const"() {value = dense<4.200000e+01> : tensor} : () -> tensor @@ -19,8 +19,8 @@ func @main() -> tensor<2xf32> { // ----- // One resource, one write. The initial value of the resource is not read. 
-// CHECK-LABEL: func @main() -> (tensor {tf.resource_name = "x"}) -func @main() { +// CHECK-LABEL: func @main(%arg0: tensor) -> (tensor {tf.resource_name = "x"}) +func @main(%arg0: tensor) { // CHECK-NOT: "tf.VarHandleOp" // CHECK-NOT: "tf.AssignVariableOp" // CHECK: return %[[CONST]] @@ -33,12 +33,12 @@ func @main() { // ----- // One resource, two reads using different resource handles. -// CHECK-LABEL: func @main(%arg0: tensor {tf.resource_name = "x"}) -> tensor<2xf32> -func @main() -> tensor<2xf32> { +// CHECK-LABEL: func @main(%arg0: tensor, %arg1: tensor {tf.resource_name = "x"}) -> tensor<2xf32> +func @main(%arg0: tensor) -> tensor<2xf32> { // CHECK-NOT: "tf.VarHandleOp" // CHECK-NOT: "tf.ReadVariableOp" - // CHECK: %[[ADD1:[0-9]*]] = "tf.AddV2"(%arg0, %[[CONST:[0-9]*]]) - // CHECK: %[[ADD2:[0-9]*]] = "tf.AddV2"(%[[ADD1]], %arg0) + // CHECK: %[[ADD1:[0-9]*]] = "tf.AddV2"(%arg1, %[[CONST:[0-9]*]]) + // CHECK: %[[ADD2:[0-9]*]] = "tf.AddV2"(%[[ADD1]], %arg1) // CHECK: %[[PACK:[0-9]*]] = "tf.Pack"(%[[CONST]], %[[ADD2]]) // CHECK: return %[[PACK]] @@ -56,12 +56,12 @@ func @main() -> tensor<2xf32> { // ----- // Two resources, two reads using different resources. -// CHECK-LABEL: func @main(%arg0: tensor {tf.resource_name = "x"}, %arg1: tensor {tf.resource_name = "y"}) -> tensor<2xf32> -func @main() -> tensor<2xf32> { +// CHECK-LABEL: func @main(%arg0: tensor, %arg1: tensor {tf.resource_name = "x"}, %arg2: tensor {tf.resource_name = "y"}) -> tensor<2xf32> +func @main(%arg0: tensor) -> tensor<2xf32> { // CHECK-NOT: "tf.VarHandleOp" // CHECK-NOT: "tf.ReadVariableOp" - // CHECK: %[[ADD1:[0-9]*]] = "tf.AddV2"(%arg0, %[[CONST:[0-9]*]]) - // CHECK: %[[ADD2:[0-9]*]] = "tf.AddV2"(%[[ADD1]], %arg1) + // CHECK: %[[ADD1:[0-9]*]] = "tf.AddV2"(%arg1, %[[CONST:[0-9]*]]) + // CHECK: %[[ADD2:[0-9]*]] = "tf.AddV2"(%[[ADD1]], %arg2) // CHECK: %[[PACK:[0-9]*]] = "tf.Pack"(%[[CONST]], %[[ADD2]]) // CHECK: return %[[PACK]] @@ -79,12 +79,12 @@ func @main() -> tensor<2xf32> { // ----- // One resource with read and write. The initial value of the resource is read. -// CHECK-LABEL: func @main(%arg0: tensor {tf.aliasing_output = 1 : i64, tf.resource_name = "x"}) -> (tensor<2xf32>, tensor) -func @main() -> tensor<2xf32> { +// CHECK-LABEL: func @main(%arg0: tensor, %arg1: tensor {tf.aliasing_output = 1 : i64, tf.resource_name = "x"}) -> (tensor<2xf32>, tensor) +func @main(%arg0: tensor) -> tensor<2xf32> { // CHECK-NOT: "tf.AssignVariableOp" - // CHECK: %[[ADD1:[0-9]*]] = "tf.AddV2"(%arg0, %{{[0-9]*}}) + // CHECK: %[[ADD1:[0-9]*]] = "tf.AddV2"(%arg1, %{{[0-9]*}}) // CHECK: %[[ADD2:[0-9]*]] = "tf.AddV2"(%[[ADD1]], %[[ADD1]]) - // CHECK: %[[PACK:[0-9]*]] = "tf.Pack"(%arg0, %[[ADD2]]) + // CHECK: %[[PACK:[0-9]*]] = "tf.Pack"(%arg1, %[[ADD2]]) // CHECK: return %[[PACK]], %[[ADD1]] %0 = "tf.Const"() {value = dense<4.200000e+01> : tensor} : () -> tensor @@ -102,8 +102,8 @@ func @main() -> tensor<2xf32> { // ----- // One resource with read and write. The initial value of the resource is not read. 
-// CHECK-LABEL: func @main() -> (tensor<2xf32>, tensor {tf.resource_name = "x"}) -func @main() -> tensor<2xf32> { +// CHECK-LABEL: func @main(%arg0: tensor) -> (tensor<2xf32>, tensor {tf.resource_name = "x"}) +func @main(%arg0: tensor) -> tensor<2xf32> { // CHECK-NOT: "tf.AssignVariableOp" // CHECK: %[[CONST:[a-z0-9]+]] = "tf.Const"() {value = dense<4.200000e+01> : tensor} // CHECK: %[[ADD1:[0-9]*]] = "tf.AddV2"(%[[CONST]], %[[CONST]]) @@ -138,8 +138,8 @@ func @cond_true(%arg0: tensor>>, %arg1: tensor) -> return %2 : tensor } -// CHECK-LABEL: func @main(%arg0: tensor {tf.resource_name = "x"}) -> tensor<2xf32> -func @main() -> tensor<2xf32> attributes {tf.entry_function = {inputs = "", outputs = "result"}} { +// CHECK-LABEL: func @main(%arg0: tensor, %arg1: tensor {tf.resource_name = "x"}) -> tensor<2xf32> +func @main(%arg0: tensor) -> tensor<2xf32> attributes {tf.entry_function = {inputs = "", outputs = "result"}} { %0 = "tf.Const"() {value = dense<1.050000e+03> : tensor} : () -> tensor %1 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> %2 = "tf.ReadVariableOp"(%1) : (tensor>>) -> tensor @@ -157,10 +157,11 @@ func @main() -> tensor<2xf32> attributes {tf.entry_function = {inputs = "", outp // Tests resource passed in as an argument is not modified and not returned. // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG_0:[a-z0-9]+]]: tensor -func @main(%arg0: tensor>>) { - %0 = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor - // CHECK-NEXT: "tf.AddV2"(%[[ARG_0]], %[[ARG_0]]) +// CHECK-SAME: %arg0: tensor +// CHECK-SAME: %[[ARG_1:[a-z0-9]+]]: tensor +func @main(%arg0: tensor, %arg1: tensor>>) { + %0 = "tf.ReadVariableOp"(%arg1) : (tensor>>) -> tensor + // CHECK-NEXT: "tf.AddV2"(%[[ARG_1]], %[[ARG_1]]) %1 = "tf.AddV2"(%0, %0) : (tensor, tensor) -> tensor // CHECK-NEXT: return return @@ -171,9 +172,10 @@ func @main(%arg0: tensor>>) { // Tests resource passed in as an argument is modified but not returned. // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG_0:[a-z0-9]+]]: tensor {tf.aliasing_output = 0 : i64} +// CHECK-SAME: %{{[a-z0-9]+}}: tensor {tf.aliasing_output = 0 : i64} +// CHECK-SAME: %arg1: tensor // CHECK-SAME: -> tensor -func @main(%arg0: tensor>>) { +func @main(%arg0: tensor>>, %arg1: tensor) { // CHECK-NEXT: %[[CONST:[a-z0-9]+]] = "tf.Const" %0 = "tf.Const"() {value = dense<4.200000e+01> : tensor} : () -> tensor "tf.AssignVariableOp"(%arg0, %0) : (tensor>>, tensor) -> () @@ -186,9 +188,10 @@ func @main(%arg0: tensor>>) { // Tests last resource assign is returned as a result. // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG_0:[a-z0-9]+]]: tensor {tf.aliasing_output = 0 : i64} +// CHECK-SAME: %{{[a-z0-9]+}}: tensor {tf.aliasing_output = 0 : i64} +// CHECK-SAME: %arg1: tensor // CHECK-SAME: -> tensor -func @main(%arg0: tensor>>) { +func @main(%arg0: tensor>>, %arg1: tensor) { %0 = "tf.Const"() {value = dense<4.200000e+01> : tensor} : () -> tensor "tf.AssignVariableOp"(%arg0, %0) : (tensor>>, tensor) -> () // CHECK: %[[CONST:[a-z0-9]+]] = "tf.Const"() {value = dense<1.050000e+03> : tensor} @@ -204,9 +207,10 @@ func @main(%arg0: tensor>>) { // returns the same value prior. 
// CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG_0:[a-z0-9]+]]: tensor {tf.aliasing_output = 1 : i64} +// CHECK-SAME: %{{[a-z0-9]+}}: tensor {tf.aliasing_output = 1 : i64} +// CHECK-SAME: %arg1: tensor // CHECK-SAME: -> (tensor, tensor) -func @main(%arg0: tensor>>) -> tensor { +func @main(%arg0: tensor>>, %arg1: tensor) -> tensor { %0 = "tf.Const"() {value = dense<4.200000e+01> : tensor} : () -> tensor "tf.AssignVariableOp"(%arg0, %0) : (tensor>>, tensor) -> () // CHECK: %[[CONST:[a-z0-9]+]] = "tf.Const"() {value = dense<1.050000e+03> : tensor} @@ -221,9 +225,10 @@ func @main(%arg0: tensor>>) -> tensor { // Tests read interleaved between writes. // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG_0:[a-z0-9]+]]: tensor {tf.aliasing_output = 1 : i64} +// CHECK-SAME: %{{[a-z0-9]+}}: tensor {tf.aliasing_output = 1 : i64} +// CHECK-SAME: %arg1: tensor // CHECK-SAME: -> (tensor, tensor) -func @main(%arg0: tensor>>) -> tensor { +func @main(%arg0: tensor>>, %arg1: tensor) -> tensor { // CHECK-NEXT: %[[CONST_0:[a-z0-9]+]] = "tf.Const"() {value = dense<4.200000e+01> : tensor} %0 = "tf.Const"() {value = dense<4.200000e+01> : tensor} : () -> tensor "tf.AssignVariableOp"(%arg0, %0) : (tensor>>, tensor) -> () @@ -271,7 +276,7 @@ func @main(%arg0: tensor>>, %arg1: tensor>>) -> tensor { %0 = "tf.VarIsInitializedOp"(%arg0) : (tensor>>) -> tensor + %1 = "tf.UnknownOp"(%arg0) : (tensor>>) -> tensor return %0 : tensor } @@ -323,7 +329,7 @@ func @main(%arg0: tensor>>) -> tensor { // Tests VarHandleOp has users that are not removed. func @main() -> tensor { - // expected-error@+1 {{expects no uses but used by operations: tf.UnknownOp, tf.VarIsInitializedOp}} + // expected-error@+1 {{expects users to be 'tf.ReadVariableOp' or 'tf.AssignVariableOp', got [tf.UnknownOp, tf.VarIsInitializedOp]}} %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> %1 = "tf.VarIsInitializedOp"(%0) : (tensor>>) -> tensor %2 = "tf.UnknownOp"(%0) : (tensor>>) -> tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/promote_var_handles_to_args.mlir b/tensorflow/compiler/mlir/tensorflow/tests/promote_var_handles_to_args.mlir new file mode 100644 index 00000000000..8b8a070cfab --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/promote_var_handles_to_args.mlir @@ -0,0 +1,59 @@ +// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-promote-var-handles-to-args | FileCheck %s -dump-input-on-failure + +// Tests main function with multiple blocks. 
+ +// expected-error@+1 {{expects function 'main' to have 1 block, got 2}} +func @main() { + br ^bb1 +^bb1: + return +} + +// ----- + +// CHECK-LABEL: func @no_args +// CHECK-SAME: (%arg0: tensor {tf.resource_name = "x"}) +// CHECK-NOT: "tf.VarHandleOp" +func @no_args() { + %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor + return +} + +// CHECK-LABEL: func @some_args +// CHECK-SAME: (%arg0: tensor, %arg1: tensor {tf.resource_name = "x"}) +// CHECK-NOT: "tf.VarHandleOp" +func @some_args(%arg0: tensor) { + %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor + return +} + +// CHECK-LABEL: func @unique_vars +// CHECK-SAME: (%arg0: tensor>> {tf.resource_name = "x"}, %arg1: tensor>> {tf.resource_name = "y"}) +// CHECK-NOT: "tf.VarHandleOp" +func @unique_vars() { + %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> + %1 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "y"} : () -> tensor>> + return +} + +// CHECK-LABEL: func @duplicate_vars +// CHECK-SAME: (%arg0: tensor>> {tf.resource_name = "x"}) +// CHECK-NOT: "tf.VarHandleOp" +func @duplicate_vars() { + %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> + %1 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> + return +} + +// CHECK-LABEL: func @duplicate_vars_with_users +// CHECK-SAME: (%arg0: tensor, %arg1: tensor>> {tf.resource_name = "x"}) +// CHECK: "tf.ReadVariableOp"(%arg1) +// CHECK: "tf.AssignAddVariableOp"(%arg1, %arg0) +// CHECK-NOT: "tf.VarHandleOp" +func @duplicate_vars_with_users(%arg0: tensor) { + %0 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> + %1 = "tf.ReadVariableOp"(%0) : (tensor>>) -> tensor + %2 = "tf.VarHandleOp"() {container = "", shape = "tfshape$", shared_name = "x"} : () -> tensor>> + "tf.AssignAddVariableOp"(%2, %arg0) : (tensor>>, tensor) -> () + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir b/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir index cfbd112a7c2..8da252fc832 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir @@ -18,11 +18,10 @@ func @controls_per_replica() { return } -// CHECK: %[[CT_0:[0-9]*]] = tf_executor.ControlTrigger -// CHECK: %[[CT_1:[0-9]*]] = tf_executor.ControlTrigger -// CHECK: %[[ISLAND_0:[a-z_0-9]*]] = tf_executor.island(%[[CT_0]], %[[CT_1]]) -// CHECK: %[[ISLAND_1:[a-z_0-9]*]] = tf_executor.island(%[[CT_0]], %[[CT_1]]) -// CHECK: %[[ISLAND_2:[a-z_0-9]*]] = tf_executor.island(%[[ISLAND_0]], %[[ISLAND_1]]) +// CHECK: %[[CT_0:.*]] = tf_executor.ControlTrigger +// CHECK: %[[CT_1:.*]] = tf_executor.ControlTrigger +// CHECK: %{{.*}} = tf_executor.island(%[[CT_0]], %[[CT_1]]) +// CHECK: %{{.*}} = tf_executor.island(%[[CT_0]], %[[CT_1]]) // Tests devices are not remapped if no devices were defined in replicate. @@ -100,64 +99,45 @@ func @remap_device() { // CHECK: device = "/GPU:1" -// Tests unused per replica island are added as a control dependency to the -// island forwarding per replica results. 
-// CHECK-LABEL: func @unused_replica_control -// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor, %[[ARG_1:[a-z0-9]*]]: tensor) -func @unused_replica_control(%arg0: tensor, %arg1: tensor) { - %0 = tf_executor.graph { - %1 = tf_executor.ControlTrigger {} - %2:2 = tf_executor.island(%1) { - %3:4 = tf_device.replicate([%arg0, %arg1] as %ri: tensor) {n = 2 : i32} { - %4 = "tf.opA"(%ri) : (tensor) -> tensor - %5 = "tf.opB"(%4) : (tensor) -> tensor - tf_device.return %4, %5 : tensor, tensor +// Tests replicate with control dependency output has each expanded replica +// control pinned to a sink island. +// CHECK-LABEL: func @replicate_control +func @replicate_control() { + tf_executor.graph { + %1 = tf_executor.island { + tf_device.replicate {n = 2 : i32} { + tf_device.return } - tf_executor.yield %3#0 : tensor + tf_executor.yield } - tf_executor.fetch %2#0 : tensor + tf_executor.fetch %1 : !tf_executor.control } return } -// CHECK: %[[CT:[0-9]*]] = tf_executor.ControlTrigger -// CHECK: %[[ISLAND_0:[a-z_0-9]*]]:2, %{{.*}} = tf_executor.island(%[[CT]]) -// CHECK: %[[OP_A_0:[0-9]*]] = "tf.opA"(%[[ARG_0]]) -// CHECK: %[[OP_B_0:[0-9]*]] = "tf.opB"(%[[OP_A_0]]) -// CHECK: tf_executor.yield %[[OP_A_0]], %[[OP_B_0]] -// CHECK: %[[ISLAND_1:[a-z_0-9]*]]:2, %[[ISLAND_1_control:[a-z_0-9]*]] = tf_executor.island(%[[CT]]) -// CHECK: %[[OP_A_1:[0-9]*]] = "tf.opA"(%[[ARG_1]]) -// CHECK: %[[OP_B_1:[0-9]*]] = "tf.opB"(%[[OP_A_1]]) -// CHECK: tf_executor.yield %[[OP_A_1]], %[[OP_B_1]] -// CHECK: %[[ISLAND_2:.*]], %[[ISLAND_2_control:.*]] = tf_executor.island(%[[ISLAND_1_control]]) -// CHECK: tf_executor.yield %[[ISLAND_0]]#0 -// CHECK: tf_executor.fetch %[[ISLAND_2]] +// CHECK: %[[REPLICA_0:.*]] = tf_executor.island +// CHECK: %[[REPLICA_1:.*]] = tf_executor.island +// CHECK: %[[SINK:.*]] = tf_executor.island(%[[REPLICA_0]], %[[REPLICA_1]]) +// CHECK: tf_executor.fetch %[[SINK]] -// Tests replicate with dynamic result shapes uses its inner ops to determine -// types for sink island. -// CHECK-LABEL: func @replicate_body_result_types -func @replicate_body_result_types() { - "tf_executor.graph"() ( { - %0:3 = "tf_executor.island"() ( { - %1:2 = "tf_device.replicate"() ( { - ^bb0: - %a = "tf.opA"() : () -> tensor - "tf_device.return"(%a) : (tensor) -> () - }) {n = 2 : i32} : () -> (tensor<*xi1>, tensor<*xi1>) - "tf_executor.yield"(%1#0, %1#1) : (tensor<*xi1>, tensor<*xi1>) -> () - }) : () -> (tensor<*xi1>, tensor<*xi1>, !tf_executor.control) - "tf_executor.fetch"(%0#2) : (!tf_executor.control) -> () - }) : () -> () +// Tests replicate results are remapped correctly. 
+// CHECK-LABEL: func @replicate_result +func @replicate_result(%arg0: tensor, %arg1: tensor) { + %0:4 = tf_executor.graph { + %1:5 = tf_executor.island { + %2:4 = tf_device.replicate([%arg0, %arg1] as %arg2: tensor) {n = 2 : i32} { + %3 = "tf.opA"(%arg2) : (tensor) -> tensor + %4 = "tf.opB"(%arg2) : (tensor) -> tensor + tf_device.return %3, %4 : tensor, tensor + } + tf_executor.yield %2#0, %2#1, %2#2, %2#3 : tensor, tensor, tensor, tensor + } + tf_executor.fetch %1#0, %1#1, %1#2, %1#3 : tensor, tensor, tensor, tensor + } return } -// CHECK: %[[ISLAND_0:.*]], %{{.*}} = tf_executor.island -// CHECK-NEXT: %[[OP_A_0:.*]] = "tf.opA"() -// CHECK-NEXT: tf_executor.yield %[[OP_A_0]] : tensor -// CHECK: %[[ISLAND_1:.*]], %{{.*}} = tf_executor.island -// CHECK-NEXT: %[[OP_A_1:.*]] = "tf.opA"() -// CHECK-NEXT: tf_executor.yield %[[OP_A_1]] : tensor -// CHECK: %[[ISLAND_2:.*]]:2, %[[ISLAND_2_CTRL:.*]] = tf_executor.island -// CHECK-NEXT: tf_executor.yield %[[ISLAND_0]], %[[ISLAND_1]] : tensor, tensor -// CHECK: tf_executor.fetch %[[ISLAND_2_CTRL]] : !tf_executor.control +// CHECK: %[[REPLICA_0:.*]]:2, %{{.*}} = tf_executor.island +// CHECK: %[[REPLICA_1:.*]]:2, %{{.*}} = tf_executor.island +// CHECK: tf_executor.fetch %[[REPLICA_0]]#0, %[[REPLICA_1]]#0, %[[REPLICA_0]]#1, %[[REPLICA_1]]#1 diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir index 793c9a601cc..9e7358ab2f5 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir @@ -9,17 +9,17 @@ func @only_resource_load() -> tensor<*xi32> { %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = i32} - // CHECK: "tf_device.launch" + // CHECK: "tf_device.cluster" // CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]]) // CHECK: tf_device.return %[[COMPUTE_RES]] - // CHECK: {device = "tpu0", launch_attr = "launch_attr"} + // CHECK: {cluster_attr = "cluster_attr"} // CHECK-SAME: () -> tensor<*xi32> - %1 = "tf_device.launch"() ( { + %1 = "tf_device.cluster"() ( { %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32> %3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>) tf_device.return %3 : tensor<*xi32> - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32> return %1 : tensor<*xi32> } @@ -34,20 +34,20 @@ func @only_resource_store() -> tensor<*xi32> { // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp" %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> - // CHECK: %[[LAUNCH_RES:[0-9]*]]:2 = "tf_device.launch" + // CHECK: %[[CLUSTER_RES:[0-9]*]]:2 = "tf_device.cluster" // CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"() // CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]] - // CHECK: {device = "tpu0", launch_attr = "launch_attr"} + // CHECK: {cluster_attr = "cluster_attr"} // CHECK-SAME: () -> (tensor<*xi32>, tensor<*xi32>) - // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = i32} + // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[CLUSTER_RES]]#1) {dtype = i32} - %1 = "tf_device.launch"() ( { + %1 = "tf_device.cluster"() ( { %2 = "tf.SomeComputation"() : () -> (tensor<*xi32>) "tf.AssignVariableOp"(%0, %2) {dtype = i32} : 
(tensor<*x!tf.resource>, tensor<*xi32>) -> () tf_device.return %2 : tensor<*xi32> - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32> - // CHECK: return %[[LAUNCH_RES]]#0 + // CHECK: return %[[CLUSTER_RES]]#0 return %1 : tensor<*xi32> } @@ -62,21 +62,21 @@ func @same_resource_load_and_store() -> tensor<*xi32> { %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]]) {dtype = i32} - // CHECK: %[[LAUNCH_RES:[0-9]*]]:2 = "tf_device.launch" + // CHECK: %[[CLUSTER_RES:[0-9]*]]:2 = "tf_device.cluster" // CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]]) // CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]] - // CHECK: {device = "tpu0", launch_attr = "launch_attr"} + // CHECK: {cluster_attr = "cluster_attr"} // CHECK-SAME: () -> (tensor<*xi32>, tensor<*xi32>) - // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[LAUNCH_RES]]#1) {dtype = i32} + // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[CLUSTER_RES]]#1) {dtype = i32} - %1 = "tf_device.launch"() ( { + %1 = "tf_device.cluster"() ( { %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32> %3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>) "tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> () tf_device.return %3 : tensor<*xi32> - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32> - // CHECK: return %[[LAUNCH_RES]]#0 + // CHECK: return %[[CLUSTER_RES]]#0 return %1 : tensor<*xi32> } @@ -87,8 +87,8 @@ func @same_resource_load_and_store() -> tensor<*xi32> { // CHECK-LABEL: func @internal_resource func @internal_resource() -> tensor<*xi32> { - // CHECK: %[[LAUNCH_RES:[0-9]*]] = "tf_device.launch" - %0 = "tf_device.launch"() ( { + // CHECK: %[[CLUSTER_RES:[0-9]*]] = "tf_device.cluster" + %0 = "tf_device.cluster"() ( { // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp" %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> @@ -104,9 +104,9 @@ func @internal_resource() -> tensor<*xi32> { // CHECK: tf_device.return %[[COMPUTE_RES]] tf_device.return %3 : tensor<*xi32> - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32> - // CHECK: return %[[LAUNCH_RES]] + // CHECK: return %[[CLUSTER_RES]] return %0 : tensor<*xi32> } @@ -120,12 +120,12 @@ func @lifting_failure() -> tensor<*xi32> { %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> // expected-error @+1 {{has remaining resource inputs that can not be lifted}} - %1 = "tf_device.launch"() ( { + %1 = "tf_device.cluster"() ( { %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32> %3 = "tf.SomeResourceOp"(%0, %2) : (tensor<*x!tf.resource>, tensor<*xi32>) -> tensor<*xi32> "tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> () tf_device.return %3 : tensor<*xi32> - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<*xi32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32> return %1 : tensor<*xi32> } @@ -135,16 +135,16 @@ func @lifting_failure() -> tensor<*xi32> { // Tests that pass lifts resource reads/writes from a loop, and removed unused // resources. 
-// CHECK-LABEL: func @launch_with_loop -func @launch_with_loop() -> () { +// CHECK-LABEL: func @cluster_with_loop +func @cluster_with_loop() -> () { // CHECK: %[[COUNT:.*]] = "tf.Const"() {value = dense<10> : tensor} %0 = "tf.Const"() {value = dense<10> : tensor} : () -> tensor // CHECK: %[[VH:.*]] = "tf.VarHandleOp"() %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> %unused = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource>> // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VH]]) - // CHECK: %[[LAUNCH:.*]] = "tf_device.launch"() - "tf_device.launch"() ( { + // CHECK: %[[CLUSTER:.*]] = "tf_device.cluster"() + "tf_device.cluster"() ( { // CHECK: %[[WHILE:.*]]:2 = "tf.While"(%[[COUNT]], %[[READ]]) %2:3 = "tf.While"(%0, %1, %unused) {body = @while_body, cond = @while_cond, device = "", is_stateless = false, @@ -153,9 +153,9 @@ func @launch_with_loop() -> () { -> (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) // CHECK: tf_device.return %[[WHILE]]#1 : tensor tf_device.return - // CHECK: {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () - // CHECK: "tf.AssignVariableOp"(%[[VH]], %[[LAUNCH]]) + // CHECK: {cluster_attr = "cluster_attr"} : () -> tensor + }) {cluster_attr = "cluster_attr"} : () -> () + // CHECK: "tf.AssignVariableOp"(%[[VH]], %[[CLUSTER]]) // CHECK: return return } @@ -188,13 +188,13 @@ func @while_cond(%arg0: tensor, %arg1: tensor<*x!tf.resource>>, // Tests that pass lifts resource reads from loop condition. -// CHECK-LABEL: func @launch_with_loop -func @launch_with_loop() -> () { +// CHECK-LABEL: func @cluster_with_loop +func @cluster_with_loop() -> () { // CHECK: %[[VH:.*]] = "tf.VarHandleOp"() %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VH]]) - // CHECK: %[[LAUNCH:.*]] = "tf_device.launch"() - "tf_device.launch"() ( { + // CHECK: %[[CLUSTER:.*]] = "tf_device.cluster"() + "tf_device.cluster"() ( { // CHECK: %[[WHILE:.*]] = "tf.While"(%[[READ]]) %1 = "tf.While"(%0) { body = @while_body, cond = @while_cond, device = "", is_stateless = false, @@ -203,9 +203,9 @@ func @launch_with_loop() -> () { -> (tensor<*x!tf.resource>>) // CHECK: tf_device.return %[[WHILE]] : tensor tf_device.return - // CHECK: {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () - // CHECK: "tf.AssignVariableOp"(%[[VH]], %[[LAUNCH]]) + // CHECK: {cluster_attr = "cluster_attr"} : () -> tensor + }) {cluster_attr = "cluster_attr"} : () -> () + // CHECK: "tf.AssignVariableOp"(%[[VH]], %[[CLUSTER]]) // CHECK: return return } @@ -230,13 +230,13 @@ func @while_cond(%arg0: tensor<*x!tf.resource>>) -> tensor { // Tests that pass lifts read-only resource reads from loop, but does not add // assign after the loop. 
-// CHECK-LABEL: func @launch_with_loop -func @launch_with_loop() -> () { +// CHECK-LABEL: func @cluster_with_loop +func @cluster_with_loop() -> () { // CHECK: %[[VH:.*]] = "tf.VarHandleOp"() %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VH]]) - // CHECK: "tf_device.launch"() - "tf_device.launch"() ( { + // CHECK: "tf_device.cluster"() + "tf_device.cluster"() ( { // CHECK: %[[WHILE:.*]] = "tf.While"(%[[READ]]) %1 = "tf.While"(%0) { body = @while_body, cond = @while_cond, device = "", is_stateless = false, @@ -245,8 +245,8 @@ func @launch_with_loop() -> () { -> (tensor<*x!tf.resource>>) // CHECK: tf_device.return tf_device.return - // CHECK: {device = "tpu0", launch_attr = "launch_attr"} : () -> () - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () + // CHECK: {cluster_attr = "cluster_attr"} : () -> () + }) {cluster_attr = "cluster_attr"} : () -> () // CHECK-NOT: "tf.AssignVariableOp" // CHECK: return return @@ -267,15 +267,15 @@ func @while_cond(%arg0: tensor<*x!tf.resource>>) -> tensor { // Tests that pass lifts resource reads from nested loops. -// CHECK-LABEL: func @launch_with_nested_loop -func @launch_with_nested_loop() -> () { +// CHECK-LABEL: func @cluster_with_nested_loop +func @cluster_with_nested_loop() -> () { // CHECK: %[[VH:.*]] = "tf.VarHandleOp"() %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> // CHECK: %[[VH_UNUSED:.*]] = "tf.VarHandleOp"() %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource>> // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VH]]) - // CHECK: %[[LAUNCH:.*]] = "tf_device.launch"() - "tf_device.launch"() ( { + // CHECK: %[[CLUSTER:.*]] = "tf_device.cluster"() + "tf_device.cluster"() ( { // CHECK: %[[WHILE:.*]] = "tf.While"(%[[READ]]) %2:2 = "tf.While"(%0, %1) { body = @while_body, cond = @while_cond, device = "", is_stateless = false, @@ -284,9 +284,9 @@ func @launch_with_nested_loop() -> () { -> (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) // CHECK: tf_device.return %[[WHILE]] : tensor tf_device.return - // CHECK: {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () - // CHECK: "tf.AssignVariableOp"(%[[VH]], %[[LAUNCH]]) + // CHECK: {cluster_attr = "cluster_attr"} : () -> tensor + }) {cluster_attr = "cluster_attr"} : () -> () + // CHECK: "tf.AssignVariableOp"(%[[VH]], %[[CLUSTER]]) // CHECK: return return } @@ -330,15 +330,15 @@ func @while_cond1(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!t // Tests that pass reports error on non-aliasing while input/output resources. -func @launch_with_loop() -> () { +func @cluster_with_loop() -> () { %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> - "tf_device.launch"() ( { + "tf_device.cluster"() ( { %1 = "tf.While"(%0) { body = @while_body, cond = @while_cond, device = "", is_stateless = false, output_shapes = [#tf.shape<>]} : (tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) tf_device.return - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () + }) {cluster_attr = "cluster_attr"} : () -> () return } func @while_body(%arg0: tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) { @@ -355,15 +355,15 @@ func @while_cond(%arg0: tensor<*x!tf.resource>>) -> tensor { // Tests that pass reports error on unsupported ops in loop body. 
-func @launch_with_loop() -> () { +func @cluster_with_loop() -> () { %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> - "tf_device.launch"() ( { + "tf_device.cluster"() ( { %1 = "tf.While"(%0) { body = @while_body, cond = @while_cond, device = "", is_stateless = false, output_shapes = [#tf.shape<>]} : (tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) tf_device.return - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () + }) {cluster_attr = "cluster_attr"} : () -> () return } func @while_body(%arg0: tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) { @@ -380,15 +380,15 @@ func @while_cond(%arg0: tensor<*x!tf.resource>>) -> tensor { // Tests that pass reports error on unsupported ops in loop cond. -func @launch_with_loop() -> () { +func @cluster_with_loop() -> () { %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> - "tf_device.launch"() ( { + "tf_device.cluster"() ( { %1 = "tf.While"(%0) { body = @while_body, cond = @while_cond, device = "", is_stateless = false, output_shapes = [#tf.shape<>]} : (tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) tf_device.return - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () + }) {cluster_attr = "cluster_attr"} : () -> () return } func @while_body(%arg0: tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) { @@ -408,16 +408,16 @@ func @while_cond(%arg0: tensor<*x!tf.resource>>) -> tensor { // Tests that pass lifts resource reads from if branches. -// CHECK: func @launch_with_if(%[[ARG0:.*]]: tensor) -> tensor<4xf32> -func @launch_with_if(%arg0: tensor) -> tensor<4xf32> { +// CHECK: func @cluster_with_if(%[[ARG0:.*]]: tensor) -> tensor<4xf32> +func @cluster_with_if(%arg0: tensor) -> tensor<4xf32> { // CHECK: %[[VH0:.*]] = "tf.VarHandleOp"() %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> // CHECK: %[[VH1:.*]] = "tf.VarHandleOp"() %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource>> // CHECK-DAG: %[[READ0:.*]] = "tf.ReadVariableOp"(%[[VH0]]) // CHECK-DAG: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[VH1]]) - // CHECK: %[[LAUNCH:.*]]:2 = "tf_device.launch"() - %2 = "tf_device.launch"() ( { + // CHECK: %[[CLUSTER:.*]]:2 = "tf_device.cluster"() + %2 = "tf_device.cluster"() ( { // CHECK: %[[IF:.*]]:2 = "tf.If"(%[[ARG0]], %[[READ0]], %[[READ1]]) %3:2 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else, output_shapes = [#tf.shape<>, #tf.shape<4>], is_stateless = false} @@ -428,10 +428,10 @@ func @launch_with_if(%arg0: tensor) -> tensor<4xf32> { %5 = "tf.AddV2"(%4, %3#1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: tf_device.return %[[ADD]], %[[IF]]#1 tf_device.return %5 : tensor<4xf32> - // CHECK: {device = "tpu0", launch_attr = "launch_attr"} : () -> (tensor<4xf32>, tensor<4xf32>) - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<4xf32> - // CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[LAUNCH]]#1) - // CHECK: return %[[LAUNCH]]#0 + // CHECK: {cluster_attr = "cluster_attr"} : () -> (tensor<4xf32>, tensor<4xf32>) + }) {cluster_attr = "cluster_attr"} : () -> tensor<4xf32> + // CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[CLUSTER]]#1) + // CHECK: return %[[CLUSTER]]#0 return %2 : tensor<4xf32> } // CHECK: func @if_then(%[[TARG0:.*]]: tensor<4xf32>, %[[TARG1:.*]]: tensor<4xf32>) @@ -457,15 +457,15 @@ func @if_else(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf. 
// Tests that pass lifts resource reads from nested if ops. -// CHECK: func @launch_with_nested_if(%[[ARG0:.*]]: tensor) -> tensor -func @launch_with_nested_if(%arg0: tensor) -> tensor { +// CHECK: func @cluster_with_nested_if(%[[ARG0:.*]]: tensor) -> tensor +func @cluster_with_nested_if(%arg0: tensor) -> tensor { // CHECK: %[[VH0:.*]] = "tf.VarHandleOp"() %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> // CHECK: %[[VH1:.*]] = "tf.VarHandleOp"() %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource>> // CHECK-DAG: %[[READ0:.*]] = "tf.ReadVariableOp"(%[[VH0]]) - // CHECK: %[[LAUNCH:.*]]:2 = "tf_device.launch"() - %2 = "tf_device.launch"() ( { + // CHECK: %[[CLUSTER:.*]]:2 = "tf_device.cluster"() + %2 = "tf_device.cluster"() ( { // CHECK: %[[IF:.*]] = "tf.If"(%[[ARG0]], %[[READ0]]) %3 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else, output_shapes = [], is_stateless = false} @@ -476,10 +476,10 @@ func @launch_with_nested_if(%arg0: tensor) -> tensor { %5 = "tf.AddV2"(%4, %4) : (tensor, tensor) -> tensor // CHECK-NEXT: tf_device.return %[[ADD]], %[[IF]] tf_device.return %5 : tensor - // CHECK: {device = "tpu0", launch_attr = "launch_attr"} : () -> (tensor, tensor) - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor - // CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[LAUNCH]]#1) - // CHECK: return %[[LAUNCH]]#0 + // CHECK: {cluster_attr = "cluster_attr"} : () -> (tensor, tensor) + }) {cluster_attr = "cluster_attr"} : () -> tensor + // CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[CLUSTER]]#1) + // CHECK: return %[[CLUSTER]]#0 return %2 : tensor } // CHECK: func @if_then(%[[TARG0:.*]]: tensor) @@ -520,10 +520,10 @@ func @inner_if_else(%arg0: tensor<*x!tf.resource>>) // Tests that the pass reports error for ambiguous resource aliasing. -func @launch_with_if(%arg0: tensor) -> tensor<4xf32> { +func @cluster_with_if(%arg0: tensor) -> tensor<4xf32> { %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource>> - %2 = "tf_device.launch"() ( { + %2 = "tf_device.cluster"() ( { // expected-error @+1 {{unsupported tf.IfOp output: resource does not alias a single input.}} %3 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else, output_shapes = [#tf.shape<>], is_stateless = false} @@ -531,7 +531,7 @@ func @launch_with_if(%arg0: tensor) -> tensor<4xf32> { -> (tensor<*x!tf.resource>>) %4 = "tf.ReadVariableOp"(%3) : (tensor<*x!tf.resource>>) -> tensor<4xf32> tf_device.return %4 : tensor<4xf32> - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<4xf32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<4xf32> return %2 : tensor<4xf32> } func @if_then(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf.resource>>) @@ -548,15 +548,15 @@ func @if_else(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf. // Tests that the pass lifts resources on two partitioned call ops sharing the // same callee. The lifting should clone the callee then modify the clone. 
-// CHECK-LABEL: @launch_with_partitioned_call -func @launch_with_partitioned_call() -> tensor { +// CHECK-LABEL: @cluster_with_partitioned_call +func @cluster_with_partitioned_call() -> tensor { // CHECK: %[[VH:.*]] = "tf.VarHandleOp"() %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> // CHECK: %[[CONST:.*]] = "tf.Const"() %1 = "tf.Const"() {value = dense<10.0> : tensor} : () -> tensor // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VH]]) - // CHECK: %[[LAUNCH:.*]] = "tf_device.launch"() - %2 = "tf_device.launch"() ( { + // CHECK: %[[CLUSTER:.*]] = "tf_device.cluster"() + %2 = "tf_device.cluster"() ( { // CHECK: %[[PC0:.*]] = "tf.PartitionedCall"(%[[CONST]], %[[READ]], %[[CONST]]) // CHECK-SAME: f = @callee_resource_lifted %3 = "tf.PartitionedCall"(%1, %0, %1) {f = @callee, config = "", config_proto = "", executor_type = ""} @@ -569,7 +569,7 @@ func @launch_with_partitioned_call() -> tensor { %5 = "tf.AddV2"(%3, %4) : (tensor, tensor) -> tensor // CHECK: tf_device.return %[[ADD]] : tensor tf_device.return %5 : tensor - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor + }) {cluster_attr = "cluster_attr"} : () -> tensor return %2 : tensor } // CHECK: @callee(%[[OA0:.*]]: tensor, %[[OA1:.*]]: tensor<*x!tf.resource>>, %[[OA2:.*]]: tensor) -> tensor @@ -592,8 +592,8 @@ func @callee(%arg0: tensor, %arg1: tensor<*x!tf.resource>>, %ar // sharing the same callee. The lifting should clone the callee then modify the // clone. -// CHECK-LABEL: @launch_with_stateful_partitioned_call -func @launch_with_stateful_partitioned_call() -> () { +// CHECK-LABEL: @cluster_with_stateful_partitioned_call +func @cluster_with_stateful_partitioned_call() -> () { // CHECK: %[[VH0:.*]] = "tf.VarHandleOp"() %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> // CHECK: %[[VH1:.*]] = "tf.VarHandleOp"() @@ -602,8 +602,8 @@ func @launch_with_stateful_partitioned_call() -> () { %2 = "tf.Const"() {value = dense<10.0> : tensor} : () -> tensor // CHECK-DAG: %[[READ0:.*]] = "tf.ReadVariableOp"(%[[VH0]]) // CHECK-DAG: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[VH1]]) - // CHECK: %[[LAUNCH:.*]] = "tf_device.launch"() - "tf_device.launch"() ( { + // CHECK: %[[CLUSTER:.*]] = "tf_device.cluster"() + "tf_device.cluster"() ( { // CHECK: %[[PC0:.*]] = "tf.StatefulPartitionedCall"(%[[READ0]], %[[READ1]], %[[CONST]]) // CHECK-SAME: f = @callee_resource_lifted %3 = "tf.StatefulPartitionedCall"(%0, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""} @@ -614,9 +614,9 @@ func @launch_with_stateful_partitioned_call() -> () { : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor) -> tensor<*x!tf.resource>> // CHECK: tf_device.return %[[PC1]] : tensor tf_device.return - // CHECK: {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () - // CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[LAUNCH]]) + // CHECK: {cluster_attr = "cluster_attr"} : () -> tensor + }) {cluster_attr = "cluster_attr"} : () -> () + // CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[CLUSTER]]) return } // CHECK: @callee(%[[OA0:.*]]: tensor<*x!tf.resource>>, %[[OA1:.*]]: tensor<*x!tf.resource>>, %[[OA2:.*]]: tensor) -> tensor<*x!tf.resource>> @@ -637,17 +637,17 @@ func @callee(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf.res // Tests that the pass reports error on called function that has resource output // which doesn't alias an input. 
-func @launch_with_stateful_partitioned_call() -> () { +func @cluster_with_stateful_partitioned_call() -> () { %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource>> %2 = "tf.Const"() {value = dense<10.0> : tensor} : () -> tensor - "tf_device.launch"() ( { + "tf_device.cluster"() ( { %3 = "tf.StatefulPartitionedCall"(%0, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""} : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor) -> tensor<*x!tf.resource>> %4 = "tf.StatefulPartitionedCall"(%3, %1, %2) {f = @callee, config = "", config_proto = "", executor_type = ""} : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor) -> tensor<*x!tf.resource>> tf_device.return - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () + }) {cluster_attr = "cluster_attr"} : () -> () return } // expected-error @+1 {{unsupported function call: resource return value does not alias an input.}} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir index 1c979b96a9a..160bba94cfc 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir @@ -390,4 +390,24 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { } return } + + // CHECK-LABEL: dont_update_for_ref + func @dont_update_for_ref() -> () { + // CHECK: () -> tensor<4x!tf.f32ref> + %11 = "tf.VariableV2"() {container = "", device = "", shape = #tf.shape<4>, shared_name = ""} : () -> tensor<4x!tf.f32ref> + // CHECK: (tensor<4x!tf.f32ref>) -> tensor<4xf32> + %12 = "tf.Identity"(%11) {device = ""} : (tensor<4x!tf.f32ref>) -> tensor<4xf32> + // CHECK: (tensor<4xf32>) -> tensor<4xf32> + %13 = "tf.Neg"(%12) {device = ""} : (tensor<4xf32>) -> tensor<4xf32> + return + } + + // CHECK-LABEL: operand_as_shape + func @operand_as_shape(%18: tensor, %39: tensor<1x4x4x32xf32>) -> () { + %cst_5 = constant dense<512> : tensor + %19 = "tf.Pack"(%18, %cst_5) {N = 2 : i64, T = i32, axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + // CHECK: -> tensor<1x512xf32> + %40 = "tf.Reshape"(%39, %19) {T = f32, Tshape = i32, device = ""} : (tensor<1x4x4x32xf32>, tensor<2xi32>) -> tensor + return + } } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/sink_constant.mlir b/tensorflow/compiler/mlir/tensorflow/tests/sink_constant.mlir index 282fa4953a5..b9c6e242e70 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/sink_constant.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/sink_constant.mlir @@ -2,7 +2,7 @@ // CHECK-LABEL: func @sink_const func @sink_const(%arg0 : tensor<16xf32>) -> (tensor<16xf32>, tensor) { - // Verify that the constant are sunk in the tf_device.launch region using them + // Verify that the constant are sunk in the tf_device.cluster region using them // and removed if no other use is left. // Only the 2.0 and 3.0 constants are removed, the 4.0 has a use in the return @@ -13,11 +13,11 @@ func @sink_const(%arg0 : tensor<16xf32>) -> (tensor<16xf32>, tensor) { %2 = "tf.Const"() {value = dense<4.000000e+00> : tensor} : () -> tensor %3 = tf_executor.graph { %res, %ctl = tf_executor.island { - %3 = "tf_device.launch"() ({ + %3 = "tf_device.cluster"() ({ // In the device region, check that the 3 constants are materialized and // remapped to the uses. 
- // CHECK: tf_device.launch + // CHECK: tf_device.cluster // CHECK-DAG: %[[CST2:.*]] = "tf.Const"{{.*}}2.0 // CHECK-DAG: %[[CST3:.*]] = "tf.Const"{{.*}}3.0 // CHECK-DAG: %[[CST4:.*]] = "tf.Const"{{.*}}4.0 @@ -31,7 +31,7 @@ func @sink_const(%arg0 : tensor<16xf32>) -> (tensor<16xf32>, tensor) { %5 = "tf.Mul"(%4, %1) : (tensor<16xf32>, tensor) -> tensor<16xf32> %6 = "tf.Mul"(%5, %2) : (tensor<16xf32>, tensor) -> tensor<16xf32> tf_device.return %6 : tensor<16xf32> - }) {device = "tpu0"} : () -> tensor<16xf32> + }) {} : () -> tensor<16xf32> tf_executor.yield %3 : tensor<16xf32> } tf_executor.fetch %res : tensor<16xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index 118ce2e8645..ffa287e0e53 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -881,20 +881,29 @@ func @testValidMatrixBandPartOpUnranked(%arg0: tensor<*xbf16>, %arg1: tensor, %arg1: tensor, %arg2: tensor) -> tensor<64x64xbf16> { - // expected-error @+1 {{op failed to verify that all of {input, band} have same type}} - %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor, tensor) -> tensor<64x64xbf16> - return %0 : tensor<64x64xbf16> +// Test valid tf.MatrixBandPart +// CHECK-LABEL: func @testValidMatrixBandPartOpUnrankedBand +func @testValidMatrixBandPartOpUnrankedBand(%arg0: tensor<64x64x64xbf16>, %arg1: tensor, %arg2: tensor) -> tensor<*xbf16> { + %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor, tensor) -> tensor<*xbf16> + return %0 : tensor<*xbf16> +} + +// ----- + +// Test valid tf.MatrixBandPart +// CHECK-LABEL: func @testValidMatrixBandPartOpCompatibleDynamicShapes +func @testValidMatrixBandPartOpCompatibleDynamicShapes(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + return %0 : tensor } // ----- // Test invalid tf.MatrixBandPart -func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64x64xbf16>, %arg1: tensor, %arg2: tensor) -> tensor<*xbf16> { - // expected-error @+1 {{op failed to verify that all of {input, band} have same type}} - %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor, tensor) -> tensor<*xbf16> - return %0 : tensor<*xbf16> +func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64x64xbf16>, %arg1: tensor, %arg2: tensor) -> tensor<64x64xbf16> { + // expected-error @+1 {{op failed to verify that all of {input, band} have dynamically equal types}} + %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor, tensor) -> tensor<64x64xbf16> + return %0 : tensor<64x64xbf16> } // ----- diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_data_fuse_map_and_batch.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_data_fuse_map_and_batch.mlir new file mode 100644 index 00000000000..39f34caf259 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_data_fuse_map_and_batch.mlir @@ -0,0 +1,29 @@ +// RUN: tf-opt -tf-standard-pipeline -tf-data-optimization %s -o %t && FileCheck %s --dump-input-on-failure < %t + +module { +// CHECK-LABEL: fuse_map_and_batch +func @fuse_map_and_batch() -> tensor attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "BatchDatasetV2"}} { + %0 = "tf.Const"() {value = dense<5> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<[0, 1, 2]> : 
tensor<3xi32>} : () -> tensor<3xi32> + // CHECK: %[[NPC:.*]] = "tf.Const"() {value = dense<1> : tensor} + // CHECK: %[[TSLICE:.*]] = "tf.TensorSliceDataset" + %3 = "tf.TensorSliceDataset"(%2) {device = "", output_shapes = [#tf.shape<>]} : (tensor<3xi32>) -> tensor<*x!tf.variant> + // CHECK: "tf.MapAndBatchDataset"(%[[TSLICE]], %[[BSIZE:.*]], %[[NPC]] + // CHECK-SAME: f = @"__inference_Dataset_map__80", + %4 = "tf.MapDataset"(%3) {device = "", + f = @"__inference_Dataset_map__80", + output_shapes = [#tf.shape<>], output_types = [i32], + preserve_cardinality = false, sloppy = false, + use_inter_op_parallelism = true} : (tensor<*x!tf.variant>) -> tensor + %5 = "tf.BatchDatasetV2"(%4, %0, %1) {device = "", output_shapes = [#tf.shape<>], output_types = [i32], parallel_copy = false} : (tensor, tensor, tensor) -> tensor + return %5 : tensor +} + +func @"__inference_Dataset_map__80"(%arg0: tensor<*xi32>) -> tensor<*xi32> { + %0 = "tf.Const"() {value = dense<2> : tensor} : () -> tensor + %1 = "tf.Mul"(%arg0, %0) {device = ""} : (tensor<*xi32>, tensor) -> tensor<*xi32> + %2 = "tf.Identity"(%1) {device = ""} : (tensor<*xi32>) -> tensor<*xi32> + return %2 : tensor<*xi32> +} +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_data_fuse_pmap_and_batch.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_data_fuse_pmap_and_batch.mlir new file mode 100644 index 00000000000..70c5c220fe1 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_data_fuse_pmap_and_batch.mlir @@ -0,0 +1,29 @@ +// RUN: tf-opt -tf-standard-pipeline -tf-data-optimization %s -o %t && FileCheck %s --dump-input-on-failure < %t + +module { +// CHECK-LABEL: fuse_pmap_and_batch +func @fuse_pmap_and_batch() -> tensor attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "BatchDatasetV2"}} { + %0 = "tf.Const"() {value = dense<5> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32> + %3 = "tf.Const"() {value = dense<12> : tensor} : () -> tensor + // CHECK: %[[TSLICE:.*]] = "tf.TensorSliceDataset" + %4 = "tf.TensorSliceDataset"(%2) {device = "", output_shapes = [#tf.shape<>]} : (tensor<3xi32>) -> tensor<*x!tf.variant> + // CHECK: "tf.MapAndBatchDataset"(%[[TSLICE]], + // CHECK-SAME: f = @"__inference_Dataset_map__80", + %5 = "tf.ParallelMapDataset"(%4, %3) {device = "", + f = @"__inference_Dataset_map__80", + output_shapes = [#tf.shape<>], output_types = [i32], + preserve_cardinality = false, sloppy = false, + use_inter_op_parallelism = true} : (tensor<*x!tf.variant>, tensor) -> tensor + %6 = "tf.BatchDatasetV2"(%5, %0, %1) {device = "", output_shapes = [#tf.shape<>], output_types = [i32], parallel_copy = false} : (tensor, tensor, tensor) -> tensor + return %6 : tensor +} + +func @"__inference_Dataset_map__80"(%arg0: tensor<*xi32>) -> tensor<*xi32> { + %0 = "tf.Const"() {value = dense<2> : tensor} : () -> tensor + %1 = "tf.Mul"(%arg0, %0) {device = ""} : (tensor<*xi32>, tensor) -> tensor<*xi32> + %2 = "tf.Identity"(%1) {device = ""} : (tensor<*xi32>) -> tensor<*xi32> + return %2 : tensor<*xi32> +} +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir index fbbbf05f116..6dceb00eefa 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir @@ -2,7 +2,7 @@ // Test ops in cluster only have 
`_tpu_replicate` and `device` attributes -// removed when moved to a launch. +// removed when moved to a `tf_device.cluster`. // CHECK-LABEL: func @cluster_ops_removed_attrs func @cluster_ops_removed_attrs() { %0 = "tf.opA"() {_tpu_replicate = "replicate", device = "device", name = "name"} : () -> tensor @@ -18,9 +18,9 @@ func @cluster_ops_removed_attrs() { // Test TPUReplicateMetadata ops `name` and `num_replicas` attributes are not -// copied over to launch. -// CHECK-LABEL: func @launch_removed_metadata_attrs -func @launch_removed_metadata_attrs() { +// copied over to `tf_device.cluster`. +// CHECK-LABEL: func @removed_metadata_attrs +func @removed_metadata_attrs() { %0 = "tf.opA"() {_tpu_replicate = "replicate"} : () -> tensor "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", name = "name", num_replicas = 1, topology = "topology"} : () -> () return @@ -42,7 +42,7 @@ func @metadata_op_removed() { // Test ops in an island with the same `_tpu_replicate` attribute are merged -// under a launch. +// under a `tf_device.cluster`. // CHECK-LABEL: func @simple_island // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) func @simple_island(%arg0 : tensor) -> tensor { @@ -60,19 +60,19 @@ func @simple_island(%arg0 : tensor) -> tensor { } // CHECK: "tf.opB" -// CHECK: %[[LAUNCH:[0-9]*]] = "tf_device.launch"() ( { +// CHECK: %[[CLUSTER:[0-9]*]] = "tf_device.cluster"() ( { // CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]]) // CHECK-NEXT: %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_A]]) // CHECK-NEXT: tf_device.return %[[OP_C]] // CHECK-NEXT: _tpu_replicate = "replicate" // CHECK-SAME: device = "device" // CHECK-SAME: topology = "topology" -// CHECK: tf_executor.yield %[[LAUNCH]] +// CHECK: tf_executor.yield %[[CLUSTER]] // Test ops in an island with the same `_tpu_replicate` attribute are merged -// under a launch, even when the associated TPUReplicateMetadata op is in a -// different island. +// under a `tf_device.cluster`, even when the associated TPUReplicateMetadata op +// is in a different island. // CHECK-LABEL: func @simple_island_separate_metadata // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) func @simple_island_separate_metadata(%arg0 : tensor) -> tensor { @@ -92,18 +92,18 @@ func @simple_island_separate_metadata(%arg0 : tensor) -> tensor { } // CHECK: "tf.opB" -// CHECK: %[[LAUNCH:[0-9]*]] = "tf_device.launch"() ( { +// CHECK: %[[CLUSTER:[0-9]*]] = "tf_device.cluster"() ( { // CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]]) // CHECK-NEXT: %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_A]]) // CHECK-NEXT: tf_device.return %[[OP_C]] // CHECK-NEXT: _tpu_replicate = "replicate" // CHECK-SAME: device = "device" // CHECK-SAME: topology = "topology" -// CHECK: tf_executor.yield %[[LAUNCH]] +// CHECK: tf_executor.yield %[[CLUSTER]] // Test ops in multiple islands with the same `_tpu_replicate` attribute are -// merged under launch ops only within their respective island. +// merged under `tf_device.cluster` ops only within their respective island. 
// CHECK-LABEL: func @multiple_islands_separate_metadata // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) func @multiple_islands_separate_metadata(%arg0 : tensor) -> (tensor, tensor) { @@ -130,28 +130,28 @@ func @multiple_islands_separate_metadata(%arg0 : tensor) -> (tensor, ten // CHECK: %[[ISLAND_1:.*]], %[[ISLAND_1_control:.*]] = tf_executor.island { // CHECK: "tf.opB" -// CHECK: %[[LAUNCH_0:[0-9]*]] = "tf_device.launch"() ( { +// CHECK: %[[CLUSTER_0:[0-9]*]] = "tf_device.cluster"() ( { // CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]]) // CHECK-NEXT: %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_A]]) // CHECK-NEXT: tf_device.return %[[OP_C]] // CHECK-NEXT: _tpu_replicate = "replicate" // CHECK-SAME: device = "device" // CHECK-SAME: topology = "topology" -// CHECK: tf_executor.yield %[[LAUNCH_0]] +// CHECK: tf_executor.yield %[[CLUSTER_0]] // CHECK: tf_executor.island { // CHECK: "tf.opE" -// CHECK: %[[LAUNCH_1:[0-9]*]] = "tf_device.launch"() ( { +// CHECK: %[[CLUSTER_1:[0-9]*]] = "tf_device.cluster"() ( { // CHECK-NEXT: %[[OP_D:[0-9]*]] = "tf.opD"(%[[ISLAND_1]]) // CHECK-NEXT: %[[OP_F:[0-9]*]] = "tf.opF"(%[[ARG_0]]) // CHECK-NEXT: tf_device.return %[[OP_F]] // CHECK-NEXT: _tpu_replicate = "replicate" // CHECK-SAME: device = "device" // CHECK-SAME: topology = "topology" -// CHECK: tf_executor.yield %[[LAUNCH_1]] +// CHECK: tf_executor.yield %[[CLUSTER_1]] // Test ops in a function body with the same `_tpu_replicate` attribute are -// merged under a launch op. +// merged under a `tf_device.cluster` op. // CHECK-LABEL: func @ops_in_func_body // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) func @ops_in_func_body(%arg0 : tensor) -> (tensor, tensor, tensor) { @@ -167,7 +167,7 @@ func @ops_in_func_body(%arg0 : tensor) -> (tensor, tensor, tensor) -> (tensor, tensor, tensor) func @nested_cluster_op_user(%arg0 : tensor) -> (tensor) { @@ -193,7 +193,7 @@ func @nested_cluster_op_user(%arg0 : tensor) -> (tensor) { return %2 : tensor } -// CHECK: %[[LAUNCH:[0-9]*]]:2 = "tf_device.launch"() ( { +// CHECK: %[[CLUSTER:[0-9]*]]:2 = "tf_device.cluster"() ( { // CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]]) // CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"(%[[OP_A]]) // CHECK-NEXT: tf_device.return %[[OP_A]], %[[OP_B]] @@ -201,8 +201,8 @@ func @nested_cluster_op_user(%arg0 : tensor) -> (tensor) { // CHECK-SAME: device = "device" // CHECK-SAME: topology = "topology" // CHECK: tf_executor.graph { -// CHECK-NEXT: tf_executor.fetch %[[LAUNCH]]#0 -// CHECK: return %[[LAUNCH]]#1 +// CHECK-NEXT: tf_executor.fetch %[[CLUSTER]]#0 +// CHECK: return %[[CLUSTER]]#1 // Test nested op of a cluster with an operand from an op of the same cluster @@ -218,7 +218,7 @@ func @nested_cluster_op(%arg0 : tensor) -> (tensor) { return %1 : tensor } -// CHECK: %[[LAUNCH:[0-9]*]] = "tf_device.launch"() ( { +// CHECK: %[[CLUSTER:[0-9]*]] = "tf_device.cluster"() ( { // CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]]) // CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"() ( { // CHECK-NEXT: "tf.opC"(%[[OP_A]]) @@ -226,7 +226,7 @@ func @nested_cluster_op(%arg0 : tensor) -> (tensor) { // CHECK-NEXT: _tpu_replicate = "replicate" // CHECK-SAME: device = "device" // CHECK-SAME: topology = "topology" -// CHECK: return %[[LAUNCH]] +// CHECK: return %[[CLUSTER]] // Test multiple clusters interleaved. 
@@ -242,21 +242,21 @@ func @interleaved_clusters(%arg0 : tensor) -> (tensor, tensor) { return %2, %3 : tensor, tensor } -// CHECK: %[[LAUNCH_0:[0-9]*]] = "tf_device.launch"() ( { +// CHECK: %[[CLUSTER_0:[0-9]*]] = "tf_device.cluster"() ( { // CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]]) // CHECK-NEXT: %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_A]]) // CHECK-NEXT: tf_device.return %[[OP_C]] // CHECK-NEXT: _tpu_replicate = "replicate_0" // CHECK-SAME: device = "device_0" // CHECK-SAME: topology = "topology_0" -// CHECK: %[[LAUNCH_1:[0-9]*]] = "tf_device.launch"() ( { +// CHECK: %[[CLUSTER_1:[0-9]*]] = "tf_device.cluster"() ( { // CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"(%[[ARG_0]]) // CHECK-NEXT: %[[OP_D:[0-9]*]] = "tf.opD"(%[[OP_B]]) // CHECK-NEXT: tf_device.return %[[OP_D]] // CHECK-NEXT: _tpu_replicate = "replicate_1" // CHECK-SAME: device = "device_1" // CHECK-SAME: topology = "topology_1" -// CHECK: return %[[LAUNCH_0]], %[[LAUNCH_1]] +// CHECK: return %[[CLUSTER_0]], %[[CLUSTER_1]] // Test operands and results of ops of a cluster that are interleaved between @@ -276,14 +276,14 @@ func @interleaved_cluster_operands_results() { // CHECK: %[[OP_C:[0-9]*]] = "tf.opC" // CHECK: %[[OP_E:[0-9]*]] = "tf.opE"(%[[OP_C]]) -// CHECK: %[[LAUNCH:[0-9]*]] = "tf_device.launch"() ( { +// CHECK: %[[CLUSTER:[0-9]*]] = "tf_device.cluster"() ( { // CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA" // CHECK-NEXT: "tf.opF"(%[[OP_E]]) // CHECK-NEXT: tf_device.return %[[OP_A]] // CHECK-NEXT: _tpu_replicate = "replicate" // CHECK-SAME: device = "device" // CHECK-SAME: topology = "topology" -// CHECK: %[[OP_B:[0-9]*]] = "tf.opB"(%[[LAUNCH]]) +// CHECK: %[[OP_B:[0-9]*]] = "tf.opB"(%[[CLUSTER]]) // CHECK: "tf.opD"(%[[OP_B]]) @@ -306,24 +306,24 @@ func @one_replica(%arg0: tensor) -> tensor { // CHECK: %[[OP_C:[0-9]*]] = "tf.opC" // CHECK: %[[OP_E:[0-9]*]] = "tf.opE"(%[[OP_C]]) -// CHECK: %[[LAUNCH:[0-9]*]]:2 = "tf_device.launch"() ( { +// CHECK: %[[CLUSTER:[0-9]*]]:2 = "tf_device.cluster"() ( { // CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]]) // CHECK-NEXT: %[[OP_F:[0-9]*]] = "tf.opF"(%[[OP_E]]) // CHECK-NEXT: tf_device.return %[[OP_A]], %[[OP_F]] // CHECK-NEXT: _tpu_replicate = "replicate" // CHECK-SAME: device = "device" // CHECK-SAME: topology = "topology" -// CHECK: %[[OP_B:[0-9]*]] = "tf.opB"(%[[LAUNCH]]#0) +// CHECK: %[[OP_B:[0-9]*]] = "tf.opB"(%[[CLUSTER]]#0) // CHECK: "tf.opD"(%[[OP_B]]) -// CHECK: return %[[LAUNCH]]#1 +// CHECK: return %[[CLUSTER]]#1 // CHECK-NOT: "tf.TPUReplicatedInput" // CHECK-NOT: "tf.TPUReplicatedOutput" // Test replication with replicated operands and replicated results. The cluster -// will be wrapped in a launch first and then by a replicate. TPUReplicatedInput -// and TPUReplicatedOutput nodes will be replaced by the replicate operands and -// results. +// will be wrapped in a `tf_device.cluster` first and then by a replicate. +// TPUReplicatedInput and TPUReplicatedOutput nodes will be replaced by the +// replicate operands and results. 
// CHECK-LABEL: func @replication // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor, %[[ARG_1:[a-z0-9]*]]: tensor, %[[ARG_2:[a-z0-9]*]]: tensor) func @replication(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor, tensor) { @@ -347,18 +347,18 @@ func @replication(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> // CHECK-DAG: [%[[ARG_0]], %[[OP_A]]] as %[[RI_0:[a-z0-9]*]]: tensor // CHECK-DAG: [%[[OP_B]], %[[ARG_1]]] as %[[RI_1:[a-z0-9]*]]: tensor // CHECK-SAME: n = 2 : i32 -// CHECK-NEXT: %[[LAUNCH:[0-9]*]]:2 = "tf_device.launch"() ( { +// CHECK-NEXT: %[[CLUSTER:[0-9]*]]:2 = "tf_device.cluster"() ( { // CHECK: %[[OP_D:[0-9]*]] = "tf.opD"(%[[RI_0]], %[[RI_1]], %[[ARG_2]], %[[OP_C]]) // CHECK: %[[OP_E:[0-9]*]] = "tf.opE"(%[[OP_D]], %[[RI_0]], %[[RI_1]], %[[ARG_2]], %[[OP_C]]) // CHECK: tf_device.return %[[OP_D]], %[[OP_E]] // CHECK-NEXT: _tpu_replicate = "replicate" // CHECK-SAME: device = "device" // CHECK-SAME: topology = "topology" -// CHECK: tf_device.return %[[LAUNCH]]#0, %[[LAUNCH]]#1 +// CHECK: tf_device.return %[[CLUSTER]]#0, %[[CLUSTER]]#1 // CHECK: return %[[REPLICATE]]#0, %[[REPLICATE]]#3 -// Test `tf.TPUReplicatedInput` ops are sorted by their `index` attribute. +// Test TPUReplicatedInput ops are sorted by their `index` attribute. // Non-negative `index` should precede `index` of -1, and ordering of ops with // `index` of -1 does not matter. // CHECK-LABEL: func @sort_replicated_input @@ -452,7 +452,7 @@ func @mismatched_replicated_output() { // Test cluster that should be replicated where its outputs do not lead to a // TPUReplicatedOutput. func @missing_replicated_output() { - // expected-error@+1 {{requires output of tf_device.launch to lead to a 'tf.TPUReplicatedOutput' op}} + // expected-error@+1 {{requires output of tf_device.cluster to lead to a 'tf.TPUReplicatedOutput' op}} %0 = "tf.opA"() {_tpu_replicate = "replicate", device = "device", name = "name"} : () -> tensor %1 = "tf.opB"(%0) : (tensor) -> tensor "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> () @@ -520,8 +520,10 @@ func @input_index_gaps(%arg0: tensor) { return } + // ----- + // Test that the `is_mirrored_variable` attribute is preserved in the // tf_device.replicate op. 
// CHECK-LABEL: func @mirrored_variables @@ -537,4 +539,3 @@ func @mirrored_variables(%arg0: tensor>>, %arg1: ten // CHECK: tf_device.replicate // CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %{{[a-z0-9]*}} // CHECK-SAME: _mirrored_variable_indices = [1] - diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_dynamic_padding_mapper.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_dynamic_padding_mapper.mlir index ad2ebc08c1d..8b610e45b4e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_dynamic_padding_mapper.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_dynamic_padding_mapper.mlir @@ -10,7 +10,7 @@ // CHECK-LABEL: func @single_arg_single_shape func @single_arg_single_shape(%arg0: tensor) { tf_device.replicate([%arg0, %arg0] as %ri_0: tensor, [%arg0, %arg0] as %ri_1: tensor) {n = 2 : i32} { - "tf_device.launch_func"(%ri_0, %ri_1) {device = "", func = @func0, padding_map = ["\10\02\18\01"]} : (tensor, tensor) -> () + "tf_device.cluster_func"(%ri_0, %ri_1) {func = @func0, padding_map = ["\10\02\18\01"]} : (tensor, tensor) -> () tf_device.return } return @@ -37,7 +37,7 @@ func @func0(%arg0: tensor, %arg1: tensor) { // CHECK-LABEL: func @single_arg_multiple_shapes func @single_arg_multiple_shapes(%arg0: tensor) { tf_device.replicate([%arg0, %arg0] as %ri_0: tensor, [%arg0, %arg0] as %ri_1: tensor, [%arg0, %arg0] as %ri_2: tensor) {n = 2 : i32} { - "tf_device.launch_func"(%ri_0, %ri_1, %ri_2) {device = "", func = @func1, padding_map = ["\10\02\18\01", "\10\03\18\02"]} : (tensor, tensor, tensor) -> () + "tf_device.cluster_func"(%ri_0, %ri_1, %ri_2) {func = @func1, padding_map = ["\10\02\18\01", "\10\03\18\02"]} : (tensor, tensor, tensor) -> () tf_device.return } return @@ -69,7 +69,7 @@ func @func1(%arg0: tensor, %arg1: tensor, %arg2: tensor) { // CHECK-LABEL: func @multiple_args func @multiple_args(%arg0: tensor) { tf_device.replicate([%arg0, %arg0] as %ri_0: tensor, [%arg0, %arg0] as %ri_1: tensor, [%arg0, %arg0] as %ri_2: tensor, [%arg0, %arg0] as %ri_3: tensor, [%arg0, %arg0] as %ri_4: tensor) {n = 2 : i32} { - "tf_device.launch_func"(%ri_0, %ri_1, %ri_2, %ri_3, %ri_4) {device = "", func = @func2, padding_map = ["\10\02\18\01", "\10\03\18\02", "\08\04\10\01\18\03"]} : (tensor, tensor, tensor, tensor, tensor) -> () + "tf_device.cluster_func"(%ri_0, %ri_1, %ri_2, %ri_3, %ri_4) {func = @func2, padding_map = ["\10\02\18\01", "\10\03\18\02", "\08\04\10\01\18\03"]} : (tensor, tensor, tensor, tensor, tensor) -> () tf_device.return } return @@ -90,7 +90,7 @@ func @func2(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tens // CHECK-LABEL: func @remap_indices func @remap_indices(%arg0: tensor) { tf_device.replicate([%arg0, %arg0] as %ri_0: tensor, [%arg0, %arg0] as %ri_1: tensor) {n = 2 : i32} { - "tf_device.launch_func"(%ri_1, %arg0, %ri_0) {device = "", func = @func3, padding_map = ["\10\02\18\01"]} : (tensor, tensor, tensor) -> () + "tf_device.cluster_func"(%ri_1, %arg0, %ri_0) {func = @func3, padding_map = ["\10\02\18\01"]} : (tensor, tensor, tensor) -> () tf_device.return } return @@ -111,7 +111,7 @@ func @func3(%arg0: tensor, %arg1: tensor, %arg2: tensor) { // padding_arg_index: 1 // CHECK-LABEL: func @no_replicate func @no_replicate(%arg0: tensor) { - "tf_device.launch_func"(%arg0, %arg0, %arg0) {device = "", func = @func4, padding_map = ["\10\02\18\01"]} : (tensor, tensor, tensor) -> () + "tf_device.cluster_func"(%arg0, %arg0, %arg0) {func = @func4, padding_map = ["\10\02\18\01"]} : (tensor, tensor, tensor) -> () return } @@ -125,7 +125,7 @@ func @func4(%arg0: 
tensor, %arg1: tensor, %arg2: tensor) { // CHECK-LABEL: func @no_padding_map func @no_padding_map(%arg0: tensor) { tf_device.replicate([%arg0, %arg0] as %ri_0: tensor, [%arg0, %arg0] as %ri_1: tensor) {n = 2 : i32} { - "tf_device.launch_func"(%ri_1, %arg0, %ri_0) {device = "", func = @func5} : (tensor, tensor, tensor) -> () + "tf_device.cluster_func"(%ri_1, %arg0, %ri_0) {func = @func5} : (tensor, tensor, tensor) -> () tf_device.return } return @@ -141,7 +141,7 @@ func @func5(%arg0: tensor, %arg1: tensor, %arg2: tensor) { // CHECK-LABEL: func @empty_padding_map func @empty_padding_map(%arg0: tensor) { tf_device.replicate([%arg0, %arg0] as %ri_0: tensor, [%arg0, %arg0] as %ri_1: tensor) {n = 2 : i32} { - "tf_device.launch_func"(%ri_1, %arg0, %ri_0) {device = "", func = @func6, padding_map = []} : (tensor, tensor, tensor) -> () + "tf_device.cluster_func"(%ri_1, %arg0, %ri_0) {func = @func6, padding_map = []} : (tensor, tensor, tensor) -> () tf_device.return } return @@ -162,7 +162,7 @@ func @func6(%arg0: tensor, %arg1: tensor, %arg2: tensor) { // CHECK-LABEL: func @unused_padding_map func @unused_padding_map(%arg0: tensor) { tf_device.replicate([%arg0, %arg0] as %ri_0: tensor, [%arg0, %arg0] as %ri_1: tensor) {n = 2 : i32} { - "tf_device.launch_func"(%ri_1) {device = "", func = @func7, padding_map = ["\10\02\18\01"]} : (tensor) -> () + "tf_device.cluster_func"(%ri_1) {func = @func7, padding_map = ["\10\02\18\01"]} : (tensor) -> () tf_device.return } return @@ -189,7 +189,7 @@ func @func7(%arg0: tensor) { func @missing_padding_arg(%arg0: tensor) { tf_device.replicate([%arg0, %arg0] as %ri_0: tensor, [%arg0, %arg0] as %ri_1: tensor, [%arg0, %arg0] as %ri_2: tensor, [%arg0, %arg0] as %ri_3: tensor) {n = 2 : i32} { // expected-warning@+1 {{bad 'padding_map' attribute at index 0, unused padding_arg_index 1}} - "tf_device.launch_func"(%ri_0, %ri_2, %ri_3) {device = "", func = @func8, padding_map = ["\10\02\18\01", "\08\02\10\02\18\03"]} : (tensor, tensor, tensor) -> () + "tf_device.cluster_func"(%ri_0, %ri_2, %ri_3) {func = @func8, padding_map = ["\10\02\18\01", "\08\02\10\02\18\03"]} : (tensor, tensor, tensor) -> () tf_device.return } return @@ -206,8 +206,8 @@ func @func8(%arg0: tensor, %arg1: tensor, %arg2: tensor) { // Test bad padding map attribute (not an array). func @bad_padding_map() { tf_device.replicate {n = 2 : i32} { - // expected-error@+1 {{'tf_device.launch_func' op requires 'padding_map' array attribute}} - "tf_device.launch_func"() {device = "", func = @_func, padding_map = 0 : i32} : () -> () + // expected-error@+1 {{'tf_device.cluster_func' op requires 'padding_map' array attribute}} + "tf_device.cluster_func"() {func = @_func, padding_map = 0 : i32} : () -> () tf_device.return } return @@ -222,8 +222,8 @@ func @_func() { // Test bad padding map attribute (element in array is not a string). func @bad_padding_map_element() { tf_device.replicate {n = 2 : i32} { - // expected-error@+1 {{'tf_device.launch_func' op bad 'padding_map' attribute at index 0, not a string}} - "tf_device.launch_func"() {device = "", func = @_func, padding_map = [0 : i32]} : () -> () + // expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, not a string}} + "tf_device.cluster_func"() {func = @_func, padding_map = [0 : i32]} : () -> () tf_device.return } return @@ -238,8 +238,8 @@ func @_func() { // Test unparsable padding map. 
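// Note on the `padding_map` byte strings used throughout these cases (decoded
// here for reference, not part of the patch): each entry is a serialized
// tensorflow::tpu::PaddingMap proto. With the usual proto wire format
// (tag byte = field_number << 3 | wire_type), \08, \10 and \18 are the varint
// tags of fields 1 (arg_index), 2 (shape_index) and 3 (padding_arg_index), and
// an omitted field defaults to 0. For example:
//
//   "\10\02\18\01"        ->  arg_index: 0, shape_index: 2, padding_arg_index: 1
//   "\08\04\10\01\18\03"  ->  arg_index: 4, shape_index: 1, padding_arg_index: 3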
func @bad_padding_map_proto() { tf_device.replicate {n = 2 : i32} { - // expected-error@+1 {{'tf_device.launch_func' op bad 'padding_map' attribute at index 0, failed to parse 'z' as tensorflow::tpu::PaddingMap}} - "tf_device.launch_func"() {device = "", func = @_func, padding_map = ["z"]} : () -> () + // expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, failed to parse 'z' as tensorflow::tpu::PaddingMap}} + "tf_device.cluster_func"() {func = @_func, padding_map = ["z"]} : () -> () tf_device.return } return @@ -259,8 +259,8 @@ func @_func() { // padding_arg_index: 1 func @negative_arg_index(%arg0: tensor) { tf_device.replicate([%arg0, %arg0] as %ri_0: tensor, [%arg0, %arg0] as %ri_1: tensor) {n = 2 : i32} { - // expected-error@+1 {{'tf_device.launch_func' op bad 'padding_map' attribute at index 0, arg_index must be in [0, 2), got -1}} - "tf_device.launch_func"(%ri_0, %ri_1) {device = "", func = @_func, padding_map = ["\08\FF\FF\FF\FF\FF\FF\FF\FF\FF\01\10\02\18\01"]} : (tensor, tensor) -> () + // expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, arg_index must be in [0, 2), got -1}} + "tf_device.cluster_func"(%ri_0, %ri_1) {func = @_func, padding_map = ["\08\FF\FF\FF\FF\FF\FF\FF\FF\FF\01\10\02\18\01"]} : (tensor, tensor) -> () tf_device.return } return @@ -280,8 +280,8 @@ func @_func(%arg0: tensor, %arg1: tensor) { // padding_arg_index: 1 func @bad_arg_index(%arg0: tensor) { tf_device.replicate([%arg0, %arg0] as %ri_0: tensor, [%arg0, %arg0] as %ri_1: tensor) {n = 2 : i32} { - // expected-error@+1 {{'tf_device.launch_func' op bad 'padding_map' attribute at index 0, arg_index must be in [0, 2), got 2}} - "tf_device.launch_func"(%ri_0, %ri_1) {device = "", func = @_func, padding_map = ["\08\02\10\02\18\01"]} : (tensor, tensor) -> () + // expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, arg_index must be in [0, 2), got 2}} + "tf_device.cluster_func"(%ri_0, %ri_1) {func = @_func, padding_map = ["\08\02\10\02\18\01"]} : (tensor, tensor) -> () tf_device.return } return @@ -301,8 +301,8 @@ func @_func(%arg0: tensor, %arg1: tensor) { // padding_arg_index: -1 func @negative_padding_arg_index(%arg0: tensor) { tf_device.replicate([%arg0, %arg0] as %ri_0: tensor, [%arg0, %arg0] as %ri_1: tensor) {n = 2 : i32} { - // expected-error@+1 {{'tf_device.launch_func' op bad 'padding_map' attribute at index 0, padding_arg_index must be in [0, 2), got -1}} - "tf_device.launch_func"(%ri_0, %ri_1) {device = "", func = @_func, padding_map = ["\08\01\10\02\18\FF\FF\FF\FF\FF\FF\FF\FF\FF\01"]} : (tensor, tensor) -> () + // expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, padding_arg_index must be in [0, 2), got -1}} + "tf_device.cluster_func"(%ri_0, %ri_1) {func = @_func, padding_map = ["\08\01\10\02\18\FF\FF\FF\FF\FF\FF\FF\FF\FF\01"]} : (tensor, tensor) -> () tf_device.return } return @@ -322,8 +322,8 @@ func @_func(%arg0: tensor, %arg1: tensor) { // padding_arg_index: 2 func @bad_padding_arg_index(%arg0: tensor) { tf_device.replicate([%arg0, %arg0] as %ri_0: tensor, [%arg0, %arg0] as %ri_1: tensor) {n = 2 : i32} { - // expected-error@+1 {{'tf_device.launch_func' op bad 'padding_map' attribute at index 0, padding_arg_index must be in [0, 2), got 2}} - "tf_device.launch_func"(%ri_0, %ri_1) {device = "", func = @_func, padding_map = ["\08\01\10\02\18\02"]} : (tensor, tensor) -> () + // expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at 
index 0, padding_arg_index must be in [0, 2), got 2}} + "tf_device.cluster_func"(%ri_0, %ri_1) {func = @_func, padding_map = ["\08\01\10\02\18\02"]} : (tensor, tensor) -> () tf_device.return } return diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir new file mode 100644 index 00000000000..eb67bdcc914 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir @@ -0,0 +1,81 @@ +// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tpu-extract-head-tail-outside-compilation | FileCheck %s --dump-input-on-failure + +// Tests extraction of a outside compiled ops at head of TPU computation. + +func @single_head_outside_compilation(%arg0 : tensor) -> () { + // CHECK: tf_device.launch + // CHECK: "tf.A" + // CHECK-NEXT: tf_device.return + // + // CHECK: "tf_device.cluster" + // CHECK: "tf.C" + // CHECK-NEXT: tf_device.return + "tf_device.cluster"() ( { + "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor) -> () + "tf.B"() : () -> () + "tf.C"() : () -> () + tf_device.return + }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () + return +} + +// CHECK-LABEL: func @multiple_head_outside_compilation +func @multiple_head_outside_compilation(%arg0 : tensor) -> () { + // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"() + // CHECK: %[[A_OUT:.*]] = "tf.A" + // CHECK: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]]) + // CHECK: "tf.C" + // CHECK-NEXT: tf_device.return %[[B_OUT]] + // + // CHECK: "tf_device.cluster" + // CHECK: "tf.D"(%[[LAUNCH_OUT]]) + // CHECK-NEXT: tf_device.return + "tf_device.cluster"() ( { + %0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor) -> (tensor) + %1 = "tf.B"(%0) {_xla_outside_compilation = "cluster1"} : (tensor) -> (tensor) + "tf.C"(%1, %arg0) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> () + "tf.D"(%1) : (tensor) -> () + tf_device.return + }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () + return +} + +// CHECK-LABEL: func @test_do_not_outside_compiled_ops_in_middle +func @test_do_not_outside_compiled_ops_in_middle(%arg0 : tensor) -> () { + // CHECK-NOT: tf_device.launch + // CHECK: "tf_device.cluster" + // CHECK-NEXT: "tf.A" + // CHECK-NEXT: "tf.B" + // CHECK-NEXT: "tf.C" + // CHECK-NEXT: tf_device.return + "tf_device.cluster"() ( { + %0 = "tf.A"(%arg0) {} : (tensor) -> (tensor) + %1 = "tf.B"(%0) {_xla_outside_compilation = "cluster1"}: (tensor) -> (tensor) + "tf.C"(%1) : (tensor) -> () + tf_device.return + }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () + return +} + +// CHECK-LABEL: func @test_ops_with_tpu_operands_not_extracted +func @test_ops_with_tpu_operands_not_extracted(%arg0 : tensor) -> () { + // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"() + // CHECK: %[[A_OUT:.*]] = "tf.A" + // CHECK: %[[D_OUT:.*]] = "tf.D"(%[[A_OUT]]) + // CHECK-NEXT: tf_device.return %[[D_OUT]] + // + // CHECK: "tf_device.cluster" + // CHECK: "tf.B" + // CHECK: "tf.C" + // CHECK: "tf.E" + // CHECK-NEXT: tf_device.return + "tf_device.cluster"() ( { + %0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor) -> (tensor) + %1 = "tf.B"() {} : () -> (tensor) + %2 = "tf.C"(%arg0, %1) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> (tensor) + %3 = "tf.D"(%0) {_xla_outside_compilation = "cluster1"}: (tensor) -> (tensor) + %4 = "tf.E"(%3) {} : (tensor) -> (tensor) + tf_device.return + }) {device = "tpu0", 
launch_attr = "launch_attr"} : () -> () + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir index b2e8f116827..3cb693ee571 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir @@ -3,12 +3,12 @@ // Tests that missing `_xla_outside_compilation` attribute value results in an error. func @missing_outside_compilation_attribute() -> () { - "tf_device.launch"() ( { + "tf_device.cluster"() ( { "tf.A"() : () -> () // expected-error@+1 {{attribute '_xla_outside_compilation' is empty}} "tf.B"() {_xla_outside_compilation = ""} : () -> () tf_device.return - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () + }) {cluster_attr = "cluster_attr"} : () -> () return } @@ -18,11 +18,11 @@ func @missing_outside_compilation_attribute() -> () { // CHECK-LABEL: func @no_outside_compilation func @no_outside_compilation() -> tensor { - %0 = "tf_device.launch"() ( { + %0 = "tf_device.cluster"() ( { %1 = "tf.A"() : () -> tensor %2 = "tf.B"(%1) : (tensor) -> tensor tf_device.return %2 : tensor - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor + }) {cluster_attr = "cluster_attr"} : () -> tensor return %0 : tensor } @@ -36,16 +36,15 @@ func @nodep_single_outside_compilation() -> () { // CHECK-NEXT: "tf_device.launch" // CHECK-NEXT: "tf.B" // CHECK-NOT: _xla_outside_compilation - // CHECK: "tf_device.launch" + // CHECK: "tf_device.cluster" // CHECK-NEXT: "tf.A" - // CHECK: device = "tpu0" - // CHECK-SAME: launch_attr = "launch_attr" - "tf_device.launch"() ( { + // CHECK: cluster_attr = "cluster_attr" + "tf_device.cluster"() ( { "tf.A"() : () -> () "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> () "tf.C"() : () -> () tf_device.return - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () + }) {cluster_attr = "cluster_attr"} : () -> () return } @@ -59,19 +58,18 @@ func @nodep_single_cluster_multiple_ops_outside_compilation() -> () { // CHECK-NEXT: "tf.C" // CHECK-NEXT: "tf.D" // CHECK-NOT: _xla_outside_compilation - // CHECK: "tf_device.launch" + // CHECK: "tf_device.cluster" // CHECK-NEXT: "tf.A" // CHECK-NEXT: "tf.E" - // CHECK: device = "tpu0" - // CHECK-SAME: launch_attr = "launch_attr" - "tf_device.launch"() ( { + // CHECK: cluster_attr = "cluster_attr" + "tf_device.cluster"() ( { "tf.A"() : () -> () "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> () "tf.C"() {_xla_outside_compilation = "cluster1"} : () -> () "tf.D"() {_xla_outside_compilation = "cluster1"} : () -> () "tf.E"() : () -> () tf_device.return - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () + }) {cluster_attr = "cluster_attr"} : () -> () return } @@ -80,15 +78,16 @@ func @nodep_single_cluster_multiple_ops_outside_compilation() -> () { // CHECK-LABEL: func @nodep_multiple_outside_compilation func @nodep_multiple_outside_compilation() -> () { // CHECK: "tf_device.parallel_execute" - // CHECK-COUNT-3: "tf_device.launch" - "tf_device.launch"() ( { + // CHECK-COUNT-2: "tf_device.launch" + // CHECK: "tf_device.cluster" + "tf_device.cluster"() ( { "tf.A"() : () -> () "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> () "tf.C"() : () -> () "tf.D"() {_xla_outside_compilation = "cluster2"} : () -> () "tf.E"() : () -> () tf_device.return - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () + }) {cluster_attr = "cluster_attr"} : () 
-> () return } @@ -100,17 +99,17 @@ func @single_tpu_return_single_outside_compilation(%arg0: tensor) -> tens // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" // CHECK-NEXT: "tf_device.launch" - // CHECK: %[[TPU_LAUNCH_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK: %[[TPU_CLUSTER_OUTPUT:[0-9]*]] = "tf_device.cluster" // CHECK: tf_device.return - // CHECK: tf_device.return %[[TPU_LAUNCH_OUTPUT]] + // CHECK: tf_device.return %[[TPU_CLUSTER_OUTPUT]] // CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]] %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { - %2 = "tf_device.launch"() ( { + %2 = "tf_device.cluster"() ( { "tf.A"() : () -> () "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> () %3 = "tf.C"() : () -> tensor tf_device.return %3 : tensor - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor + }) {cluster_attr = "cluster_attr"} : () -> tensor tf_device.return %2 : tensor } @@ -125,17 +124,17 @@ func @multiple_tpu_return_single_outside_compilation(%arg0: tensor) -> te // CHECK: %[[REPLICATE:[0-9]*]]:4 = tf_device.replicate // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:2 = "tf_device.parallel_execute" // CHECK-NEXT: "tf_device.launch" - // CHECK: %[[TPU_LAUNCH_OUTPUT:[0-9]*]]:2 = "tf_device.launch" + // CHECK: %[[TPU_CLUSTER_OUTPUT:[0-9]*]]:2 = "tf_device.cluster" // CHECK: tf_device.return - // CHECK: tf_device.return %[[TPU_LAUNCH_OUTPUT]] + // CHECK: tf_device.return %[[TPU_CLUSTER_OUTPUT]] // CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]] %1:4 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { - %2, %3 = "tf_device.launch"() ( { + %2, %3 = "tf_device.cluster"() ( { %4 = "tf.A"() : () -> tensor "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> () %5 = "tf.C"() : () -> tensor tf_device.return %4, %5 : tensor, tensor - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> (tensor, tensor) + }) {cluster_attr = "cluster_attr"} : () -> (tensor, tensor) tf_device.return %2, %3 : tensor, tensor } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir index 06d6c35e0a8..b8a48bbb379 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir @@ -5,7 +5,7 @@ // expected-error@+1 {{requires attribute 'tf.versions'}} module attributes {tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @missing_tf_versions() { - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -20,7 +20,7 @@ module attributes {tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_devices() { // 
expected-error@+1 {{error in fetching TPU compilation/execution devices: no TPU_SYSTEM devices found}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -30,13 +30,13 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with missing `num_cores_per_replicas` +// Tests `tf_device.cluster_func` with missing `num_cores_per_replicas` // attribute. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @missing_num_cores_per_replica() { // expected-error@+1 {{requires attribute 'num_cores_per_replica'}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -46,12 +46,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with bad `num_cores_per_replicas` attribute. +// Tests `tf_device.cluster_func` with bad `num_cores_per_replicas` attribute. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_num_cores_per_replica() { // expected-error@+1 {{requires attribute 'num_cores_per_replica'}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = "", step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = "", step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -61,12 +61,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with missing `step_marker_location` attribute. +// Tests `tf_device.cluster_func` with missing `step_marker_location` attribute. 
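// For reference while reading the negative cases that follow (sketch only, not
// part of the patch): a `tf_device.cluster_func` accepted by this pass carries
// the full attribute set, as in the well-formed examples further down in this
// file:
//
//   "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func,
//       num_cores_per_replica = 1, step_marker_location = "", padding_map = [],
//       topology = "", device_assignment = [], input_sharding_configuration = [],
//       output_sharding_configuration = []} : () -> ()
//
// Each case below drops, mistypes, or makes unparsable one of these attributes
// and expects the corresponding error from the pass.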
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_num_cores_per_replica() { // expected-error@+1 {{requires attribute 'step_marker_location'}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -76,12 +76,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with bad `step_marker_location` attribute. +// Tests `tf_device.cluster_func` with bad `step_marker_location` attribute. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_step_marker_location() { // expected-error@+1 {{requires attribute 'step_marker_location'}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = 1, padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = 1, padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -91,12 +91,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with unparsable `step_marker_location` attribute. +// Tests `tf_device.cluster_func` with unparsable `step_marker_location` attribute. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @unparsable_step_marker_location() { // expected-error@+1 {{bad 'step_marker_location' attribute with value 'test'}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "test", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "test", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -106,12 +106,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with missing `padding_map` attribute. +// Tests `tf_device.cluster_func` with missing `padding_map` attribute. 
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @missing_padding_map() { // expected-error@+1 {{requires attribute 'padding_map'}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -121,12 +121,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with bad `padding_map` attribute. +// Tests `tf_device.cluster_func` with bad `padding_map` attribute. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_padding_map() { // expected-error@+1 {{requires attribute 'padding_map'}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = "", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = "", topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -136,12 +136,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with bad element in `padding_map` attribute. +// Tests `tf_device.cluster_func` with bad element in `padding_map` attribute. 
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_element_padding_map() { // expected-error@+1 {{bad 'padding_map' attribute at index 0, not a string}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [1], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [1], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -151,12 +151,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with unparsable element in `padding_map` attribute. +// Tests `tf_device.cluster_func` with unparsable element in `padding_map` attribute. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @unparsable_element_padding_map() { // expected-error@+1 {{bad 'padding_map' attribute at index 0 with value 'test': failed to parse to tpu::PaddingMap}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["test"], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["test"], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -166,12 +166,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with missing `topology` attribute. +// Tests `tf_device.cluster_func` with missing `topology` attribute. 
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @missing_topology() { // expected-error@+1 {{requires attribute 'topology'}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -181,12 +181,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with bad `topology` attribute. +// Tests `tf_device.cluster_func` with bad `topology` attribute. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_topology() { // expected-error@+1 {{requires attribute 'topology'}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = 1 : i32, device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = 1 : i32, device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -196,12 +196,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with `topology` attribute resulting in device assignment error. +// Tests `tf_device.cluster_func` with `topology` attribute resulting in device assignment error. 
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @invalid_topology() { // expected-error@+1 {{error in fetching TPU compilation/execution devices}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "test", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "test", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -211,12 +211,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with missing `device_assignment` attribute. +// Tests `tf_device.cluster_func` with missing `device_assignment` attribute. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @missing_device_assignment() { // expected-error@+1 {{requires attribute 'device_assignment'}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -226,12 +226,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with bad `device_assignment` attribute. +// Tests `tf_device.cluster_func` with bad `device_assignment` attribute. 
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_device_assignment() { // expected-error@+1 {{requires attribute 'device_assignment'}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = "", input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = "", input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -241,12 +241,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with bad element in `device_assignment` attribute. +// Tests `tf_device.cluster_func` with bad element in `device_assignment` attribute. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_element_device_assignment() { // expected-error@+1 {{bad 'device_assignment' attribute at index 0, not an int}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [""], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [""], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -277,12 +277,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with `device_assignment` attribute resulting in device assignment error. +// Tests `tf_device.cluster_func` with `device_assignment` attribute resulting in device assignment error. 
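// Note on the `topology` byte string used in the next case (decoded here for
// reference, not part of the patch; assumes the standard
// tensorflow.tpu.TopologyProto field numbering):
//
//   "\0A\03\01\01\02\10\01\18\02\22\06\00\00\00\00\00\01"
//     ->  mesh_shape: [1, 1, 2]
//         num_tasks: 1
//         num_tpu_devices_per_task: 2
//         device_coordinates: [0, 0, 0,  0, 0, 1]
//
// i.e. a topology describing two TPU devices on one task, while `tf.devices`
// lists only a single TPU device, which plausibly explains the expected device
// assignment error.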
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @invalid_device_assignment() { // expected-error@+1 {{error in fetching TPU compilation/execution devices}} - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "\0A\03\01\01\02\10\01\18\02\22\06\00\00\00\00\00\01", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "\0A\03\01\01\02\10\01\18\02\22\06\00\00\00\00\00\01", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () return } func @empty_func() { @@ -292,12 +292,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with missing `input_sharding_configuration` attribute. +// Tests `tf_device.cluster_func` with missing `input_sharding_configuration` attribute. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @missing_input_sharding_configuration(%arg0: tensor) { // expected-error@+1 {{requires attribute 'input_sharding_configuration'}} - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_ENTRY", padding_map = [], topology = "", device_assignment = [], output_sharding_configuration = []} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_ENTRY", padding_map = [], topology = "", device_assignment = [], output_sharding_configuration = []} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -317,12 +317,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with bad `input_sharding_configuration` attribute. +// Tests `tf_device.cluster_func` with bad `input_sharding_configuration` attribute. 
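// Note on the sharding byte string used in the remaining cases (decoded here
// for reference, not part of the patch; assumes the standard xla.OpSharding
// field numbering):
//
//   "\08\01\1A\01\01\22\01\00"
//     ->  type: MAXIMAL
//         tile_assignment_dimensions: [1]
//         tile_assignment_devices: [0]
//
// i.e. a maximal sharding assigning the value to logical device 0, used as the
// well-formed entry in the input/output_sharding_configuration attributes of
// these cases.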
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_input_sharding_configuration(%arg0: tensor) { // expected-error@+1 {{requires attribute 'input_sharding_configuration'}} - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = "", output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = "", output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -332,12 +332,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with mismatched `input_sharding_configuration` attribute size. +// Tests `tf_device.cluster_func` with mismatched `input_sharding_configuration` attribute size. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @mismatched_size_input_sharding_configuration(%arg0: tensor) { // expected-error@+1 {{bad 'input_sharding_configuration' attribute, expected array attribute of size 1, got size 0}} - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -347,12 +347,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with unsupported operand type. +// Tests `tf_device.cluster_func` with unsupported operand type. 
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @unsupported_operand_type(%arg0: tensor) { // expected-error@+1 {{failed to determine operand type at index 0: Converting i2 to DataType}} - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_ENTRY", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_ENTRY", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -362,12 +362,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with bad element in `input_sharding_configuration` attribute. +// Tests `tf_device.cluster_func` with bad element in `input_sharding_configuration` attribute. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_element_input_sharding_configuration(%arg0: tensor) { // expected-error@+1 {{bad 'input_sharding_configuration' attribute at index 0, not a string}} - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [1], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [1], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -377,12 +377,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with unparsable element in `input_sharding_configuration` attribute. +// Tests `tf_device.cluster_func` with unparsable element in `input_sharding_configuration` attribute. 
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @unparsable_element_input_sharding_configuration(%arg0: tensor) { // expected-error@+1 {{bad 'input_sharding_configuration' attribute at index 0 with value 'test': failed to parse to xla::OpSharding}} - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["test"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["test"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -392,12 +392,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with missing `output_sharding_configuration` attribute. +// Tests `tf_device.cluster_func` with missing `output_sharding_configuration` attribute. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @missing_output_sharding_configuration(%arg0: tensor) { // expected-error@+1 {{requires attribute 'output_sharding_configuration'}} - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_ENTRY", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_ENTRY", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -407,12 +407,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with bad `output_sharding_configuration` attribute. +// Tests `tf_device.cluster_func` with bad `output_sharding_configuration` attribute. 
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_output_sharding_configuration(%arg0: tensor) { // expected-error@+1 {{requires attribute 'output_sharding_configuration'}} - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ""} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ""} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -422,12 +422,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with mismatched `output_sharding_configuration` attribute size. +// Tests `tf_device.cluster_func` with mismatched `output_sharding_configuration` attribute size. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @mismatched_size_output_sharding_configuration(%arg0: tensor) { // expected-error@+1 {{bad 'output_sharding_configuration' attribute, expected array attribute of size 1, got size 0}} - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = []} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = []} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -438,12 +438,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with bad element in `output_sharding_configuration` attribute. +// Tests `tf_device.cluster_func` with bad element in `output_sharding_configuration` attribute. 
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @bad_element_output_sharding_configuration(%arg0: tensor) { // expected-error@+1 {{bad 'output_sharding_configuration' attribute at index 0, not a string}} - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = [1]} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = [1]} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -453,12 +453,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with unparsable element in `output_sharding_configuration` attribute. +// Tests `tf_device.cluster_func` with unparsable element in `output_sharding_configuration` attribute. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { func @unparsable_element_output_sharding_configuration(%arg0: tensor) { // expected-error@+1 {{bad 'output_sharding_configuration' attribute at index 0 with value 'test': failed to parse to xla::OpSharding}} - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["test"]} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["test"]} : (tensor) -> tensor return } func @empty_func(%arg0: tensor) -> tensor { @@ -468,7 +468,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests `tf_device.launch_func` with empty `step_marker_location` attribute +// Tests `tf_device.cluster_func` with empty `step_marker_location` attribute // defaults to `STEP_MARK_AT_ENTRY`. 
// // The expected TPUCompileMetadataProto is: @@ -478,7 +478,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { // CHECK-LABEL: func @default_step_marker_location func @default_step_marker_location() { - "tf_device.launch_func"() {_tpu_replicate = "cluster0", device = "", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () + "tf_device.cluster_func"() {_tpu_replicate = "cluster0", func = @empty_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = [], output_sharding_configuration = []} : () -> () // CHECK: metadata // CHECK-SAME: num_replicas: 1 // CHECK-SAME: num_cores_per_replica: 1 @@ -497,7 +497,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { // CHECK-LABEL: func @unranked_shape_arg func @unranked_shape_arg(%arg0: tensor<*xi32>) -> tensor<*xi32> { - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<*xi32>) -> tensor<*xi32> + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<*xi32>) -> tensor<*xi32> // CHECK: metadata // CHECK-SAME: shape {\0A unknown_rank: true @@ -515,7 +515,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { // CHECK-LABEL: func @partial_shape_arg func @partial_shape_arg(%arg0: tensor) -> tensor { - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: metadata // CHECK-SAME: args // CHECK-SAME: shape {\0A dim {\0A size: -1\0A }\0A dim {\0A size: -1\0A }\0A dim {\0A size: 3\0A }\0A } @@ 
-546,7 +546,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { // CHECK-LABEL: func @static_shape_arg func @static_shape_arg(%arg0: tensor<1x2x3xi32>) -> tensor<1x2x3xi32> { - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<1x2x3xi32>) -> tensor<1x2x3xi32> + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<1x2x3xi32>) -> tensor<1x2x3xi32> // CHECK: metadata // CHECK-SAME: args // CHECK-SAME: shape @@ -571,7 +571,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { // CHECK-LABEL: func @resource_arg func @resource_arg(%arg0: tensor<*x!tf.resource>) -> tensor<*x!tf.resource> { - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<*x!tf.resource>) -> tensor<*x!tf.resource> + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<*x!tf.resource>) -> tensor<*x!tf.resource> // CHECK: metadata // CHECK: dtype: DT_RESOURCE // CHECK-SAME: kind: VARIABLE @@ -590,7 +590,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { // CHECK-LABEL: func @parameter_arg func @parameter_arg(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<*xf32>) -> tensor<*xf32> + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = 
["\08\01\1A\01\01\22\01\00"]} : (tensor<*xf32>) -> tensor<*xf32> // CHECK: metadata // CHECK: dtype: DT_FLOAT // CHECK-SAME: kind: PARAMETER @@ -614,7 +614,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests metadata is populated correctly based on launch_func op and attributes. +// Tests metadata is populated correctly based on cluster_func op and attributes. // // The expected TPUCompileMetadataProto is: // args { @@ -650,7 +650,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { // CHECK-LABEL: func @metadata func @metadata(%arg0: tensor<8xi32>) -> tensor<8xi32> { - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32> + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32> // CHECK: metadata // CHECK-SAME: args // CHECK-SAME: dtype: DT_INT32 @@ -694,7 +694,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NOT: "tf.Shape"(%[[ARG_3]]) // CHECK: %[[ARG_0_SHAPE:[0-9]*]] = "tf.Shape"(%[[ARG_0]]) // CHECK: %[[ARG_2_SHAPE:[0-9]*]] = "tf.Shape"(%[[ARG_2]]) - %0 = "tf_device.launch_func"(%arg0, %arg1, %arg2, %arg3) {_tpu_replicate = "cluster0", device = "", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<*xi32>, tensor<8xi32>, tensor<*xi32>, tensor<8xi32>) -> tensor<8xi32> + %0 = "tf_device.cluster_func"(%arg0, %arg1, %arg2, %arg3) {_tpu_replicate = "cluster0", func = @_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<*xi32>, tensor<8xi32>, tensor<*xi32>, tensor<8xi32>) -> tensor<8xi32> // CHECK: "tf._TPUCompileMlir"(%[[ARG_0_SHAPE]], %[[ARG_2_SHAPE]]) return %0: tensor<8xi32> @@ -706,16 +706,16 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests simple case of `tf_device.launch_func` on TPU with single input and +// Tests simple case of `tf_device.cluster_func` on TPU with single input and // single output. 
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { - // CHECK-LABEL: func @single_tpu_launch_func - func @single_tpu_launch_func(%arg0: tensor) -> tensor { + // CHECK-LABEL: func @single_tpu_cluster_func + func @single_tpu_cluster_func(%arg0: tensor) -> tensor { %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) @@ -747,12 +747,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests simple case of `tf_device.launch_func` on TPU with replication. +// Tests simple case of `tf_device.cluster_func` on TPU with replication. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} { - // CHECK-LABEL: func @replicated_tpu_launch_func + // CHECK-LABEL: func @replicated_tpu_cluster_func // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) - func @replicated_tpu_launch_func(%arg0: tensor) -> tensor { + func @replicated_tpu_cluster_func(%arg0: tensor) -> tensor { // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" %0 = "tf.A"(%arg0) : (tensor) -> tensor @@ -775,7 +775,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0" // CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch" // CHECK-NEXT: "tf.TPUExecute"(%[[RI_0]], %[[COMPILE_OUTPUT]]#1) - %2 = "tf_device.launch_func"(%ri_0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %2 = "tf_device.cluster_func"(%ri_0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: tf_device.return %[[EXECUTE_OUTPUT]] tf_device.return %2 : tensor @@ -796,15 +796,15 @@ 
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests that launch_func without _tpu_replicate attribute is ignored. +// Tests that cluster_func without _tpu_replicate attribute is ignored. module attributes {tf.versions = {producer = 888 : i32}} { - // CHECK-LABEL: func @single_gpu_launch_func - func @single_gpu_launch_func(%arg0: tensor) -> tensor { + // CHECK-LABEL: func @single_gpu_cluster_func + func @single_gpu_cluster_func(%arg0: tensor) -> tensor { %0 = "tf.A"(%arg0) : (tensor) -> tensor - %1 = "tf_device.launch_func"(%0) {device = "gpu0", func = @gpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor - // CHECK: tf_device.launch_func + %1 = "tf_device.cluster_func"(%0) {device = "gpu0", func = @gpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + // CHECK: tf_device.cluster_func // CHECK-SAME: device = "gpu0" // CHECK-SAME: func = @gpu0_func // CHECK-SAME: num_cores_per_replica = 1 @@ -823,7 +823,7 @@ module attributes {tf.versions = {producer = 888 : i32}} { // ----- -// Tests of `tf_device.launch_func` on TPU with nested function calls. +// Tests of `tf_device.cluster_func` on TPU with nested function calls. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { // CHECK-LABEL: func @with_nested_func @@ -831,7 +831,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) @@ -871,7 +871,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests of `tf_device.launch_func` on TPU with referenced function that's not +// Tests of `tf_device.cluster_func` on TPU with referenced function that's not // via a standard call op. 
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { @@ -880,7 +880,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) @@ -916,7 +916,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests rewriting `tf_device.launch_func` on TPU with a chain of referenced +// Tests rewriting `tf_device.cluster_func` on TPU with a chain of referenced // functions. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { @@ -925,7 +925,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) @@ -969,7 +969,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests rewriting `tf_device.launch_func` on TPU with multiple calls to same +// Tests rewriting `tf_device.cluster_func` on TPU with multiple calls to same // function. 
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { @@ -978,7 +978,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) @@ -1017,15 +1017,15 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests multiple `tf_device.launch_func` on TPU with different computation. +// Tests multiple `tf_device.cluster_func` on TPU with different computation. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { - // CHECK-LABEL: func @multiple_launch_different_func - func @multiple_launch_different_func(%arg0: tensor) -> tensor { + // CHECK-LABEL: func @multiple_cluster_different_func + func @multiple_cluster_different_func(%arg0: tensor) -> tensor { %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func0, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func0, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) @@ -1039,7 +1039,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[EXECUTE0_OUTPUT:[0-9]*]] = "tf_device.launch" // CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE0_OUTPUT]]#1) - %2 = "tf_device.launch_func"(%1) {_tpu_replicate = "cluster1", device = "", func = @tpu0_func1, num_cores_per_replica = 1, step_marker_location = 
"STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %2 = "tf_device.cluster_func"(%1) {_tpu_replicate = "cluster1", func = @tpu0_func1, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: %[[EXECUTE0_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[EXECUTE0_OUTPUT]]) // CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[EXECUTE0_SHAPE_OUTPUT]]) @@ -1073,15 +1073,15 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests multiple `tf_device.launch_func` on TPU with same computation. +// Tests multiple `tf_device.cluster_func` on TPU with same computation. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { - // CHECK-LABEL: func @multiple_launch_same_func - func @multiple_launch_same_func(%arg0: tensor) -> tensor { + // CHECK-LABEL: func @multiple_cluster_same_func + func @multiple_cluster_same_func(%arg0: tensor) -> tensor { %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) @@ -1095,7 +1095,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[EXECUTE0_OUTPUT:[0-9]*]] = "tf_device.launch" // CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE0_OUTPUT]]#1) - %2 = "tf_device.launch_func"(%1) {_tpu_replicate = "cluster1", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %2 = "tf_device.cluster_func"(%1) {_tpu_replicate = "cluster1", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = 
["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: %[[EXECUTE0_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[EXECUTE0_OUTPUT]]) // CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[EXECUTE0_SHAPE_OUTPUT]]) @@ -1128,12 +1128,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ArrayAttr and DictionaryAttr. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { - // CHECK-LABEL: func @single_tpu_launch_func - func @single_tpu_launch_func(%arg0: tensor) -> tensor { + // CHECK-LABEL: func @single_tpu_cluster_func + func @single_tpu_cluster_func(%arg0: tensor) -> tensor { %0 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]]) @@ -1203,7 +1203,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: "tf.TPUCompileSucceededAssert" // CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch" // CHECK-NEXT: "tf.TPUExecute" - %1 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + %1 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor %compile_result = "tf.TPUCompilationResult"() {_tpu_replicate = "cluster0"} : () -> tensor %compile_result2 = "tf.TPUCompilationResult"() {_tpu_replicate = "cluster0"} : () -> tensor @@ -1222,6 +1222,41 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- +// Tests simple case of `tf_device.cluster_func` on TPU with replication and parallel_execute. 
+ +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} { + // CHECK-LABEL: func @replicated_parallel_tpu_cluster_func + func @replicated_parallel_tpu_cluster_func(%arg0: tensor) -> tensor { + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + %0 = "tf.A"(%arg0) : (tensor) -> tensor + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + // CHECK: "tf._TPUCompileMlir" + // CHECK: "tf.TPUCompileSucceededAssert" + // CHECK: "tf_device.parallel_execute" + // CHECK: "tf.TPUExecute" + %3 = "tf_device.parallel_execute"() ( { + "tf.D"() : () -> () + tf_device.return + }, { + %4 = "tf_device.cluster_func"(%ri_0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor) -> tensor + + tf_device.return %4 : tensor + }) : () -> (tensor) + tf_device.return %3 : tensor + } + %2 = "tf.C"(%1#1) : (tensor) -> tensor + return %2 : tensor + } + + func @tpu0_func(%arg0: tensor) -> tensor { + %0 = "tf.B"(%arg0) : (tensor) -> tensor + return %0 : tensor + } +} + +// ----- + // Tests devices are set properly for non replicated model parallelism. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"]} { @@ -1244,7 +1279,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK-NEXT: "tf.TPUExecute" // CHECK-NEXT: tf_device.return // CHECK-NEXT: device = "/job:localhost/replica:0/task:0/device:TPU:1" - %0 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32> + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32> return %0 : tensor<8xi32> } func @tpu0_func(%arg0: tensor<8xi32>) -> tensor<8xi32> { @@ -1309,7 +1344,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK-NEXT: "tf.TPUExecute" // CHECK-NEXT: tf_device.return // CHECK-NEXT: device = "TPU_REPLICATED_CORE_1" - %1 = "tf_device.launch_func"(%ri) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location 
= "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32> + %1 = "tf_device.cluster_func"(%ri) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32> tf_device.return %1 : tensor<8xi32> } return %0#0, %0#1 : tensor<8xi32>, tensor<8xi32> @@ -1344,7 +1379,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK: "tf_device.launch" // CHECK-NEXT: "tf.TPUExecute"(%[[RI_1]], %[[RI_2]], %[[COMPILE]]#2) // CHECK: device = "TPU_REPLICATED_CORE_1" - %1 = "tf_device.launch_func"(%ri, %ri2, %ri3) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [""], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "", ""], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>, tensor<*xi1>, tensor<*xi32>) -> tensor<8xi32> + %1 = "tf_device.cluster_func"(%ri, %ri2, %ri3) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [""], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "", ""], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>, tensor<*xi1>, tensor<*xi32>) -> tensor<8xi32> tf_device.return %1 : tensor<8xi32> } return %0#0, %0#1 : tensor<8xi32>, tensor<8xi32> @@ -1382,7 +1417,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK: "tf_device.launch" // CHECK-NEXT: "tf.TPUExecute" // CHECK: device = "TPU_REPLICATED_CORE_1" - %1 = "tf_device.launch_func"(%ri) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32> + %1 = "tf_device.cluster_func"(%ri) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = 
"\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32> tf_device.return %1 : tensor<8xi32> } return %0#0, %0#1 : tensor<8xi32>, tensor<8xi32> @@ -1420,7 +1455,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute" // CHECK-NEXT: tf_device.return %[[EXECUTE_1_OUTPUT]] // CHECK: device = "TPU_REPLICATED_CORE_1" - %1, %2 = "tf_device.launch_func"(%ri) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<8xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<8xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> @@ -1487,7 +1522,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_OUT]]#1, %[[RI_1]], %[[COMPILE]]#2) // CHECK-NEXT: tf_device.return %[[EXECUTE_1_OUTPUT]] // CHECK: device = "TPU_REPLICATED_CORE_1" - %1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } return %0#0, %1#0 : 
tensor<*xi32>, tensor<*xi1> @@ -1555,7 +1590,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK: %[[CONST_CONCAT_DIM:[0-9]*]] = "tf.Const"() // CHECK: %[[CONCAT_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT_DIM]], %[[PARALLEL_EXECUTE_OUTPUT]]#0, %[[PARALLEL_EXECUTE_OUTPUT]]#2 - %1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "", padding_map = [""], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\00"]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "", padding_map = [""], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\00"]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> @@ -1598,7 +1633,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc func @uneven_input_sharding_disallowed(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} { // expected-error@+1 {{incorrect input sharding configuration received. 
1-th dimension of the input must be evenly divisible by 4}} - %1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [""], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [""], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> @@ -1638,7 +1673,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc func @uneven_output_sharding_disallowed(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} { // expected-error@+1 {{incorrect sharding format for outputs. 
Number of tiled outputs(4) must match the number of logical devices(2)}} - %1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [""], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["", ""], output_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = [""], topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["", ""], output_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> @@ -1744,7 +1779,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK: %[[LAUNCH_3_OUTPUT:[0-9]*]] = "tf_device.launch" // CHECK-NEXT: %[[EXECUTE_3_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_2_OUT]]#1, %[[COMPILE]]#4) // CHECK: tf_device.return %[[EXECUTE_3_OUTPUT]] - %1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> @@ -1851,7 +1886,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK: 
%[[LAUNCH_3_OUTPUT:[0-9]*]] = "tf_device.launch" // CHECK-NEXT: %[[EXECUTE_3_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_2_OUT]]#1, %[[COMPILE]]#4) // CHECK: tf_device.return %[[EXECUTE_3_OUTPUT]] - %1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> @@ -1935,7 +1970,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK: %[[CONCAT2_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT2_DIM]], %[[PARALLEL_EXECUTE_OUTPUT]]#3, %[[PARALLEL_EXECUTE_OUTPUT]]#4 // CHECK: %[[CONST_CONCAT3_DIM:[0-9]*]] = "tf.Const"() // CHECK: %[[CONCAT3_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT3_DIM]], %[[CONCAT_OUTPUT]], %[[CONCAT2_OUTPUT]] - %1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "", padding_map = [""], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\00"]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "", padding_map = [""], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = 
["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\00"]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> @@ -2020,7 +2055,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK: %[[LAUNCH_3_OUTPUT:[0-9]*]] = "tf_device.launch" // CHECK-NEXT: %[[EXECUTE_3_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_1_OUT]]#0, %[[COMPILE]]#4) // CHECK: tf_device.return %[[EXECUTE_3_OUTPUT]] - %1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\03\02\01\00", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\03\02\01\00", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> @@ -2104,7 +2139,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK: %[[CONCAT2_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT2_DIM]], %[[PARALLEL_EXECUTE_OUTPUT]]#2, %[[PARALLEL_EXECUTE_OUTPUT]]#0 // CHECK: %[[CONST_CONCAT3_DIM:[0-9]*]] = "tf.Const"() // CHECK: %[[CONCAT3_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT3_DIM]], %[[CONCAT_OUTPUT]], %[[CONCAT2_OUTPUT]] - %1, %2 = "tf_device.launch_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "", padding_map = [""], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\03\02\01\00", "\08\01\1A\01\01\22\01\00"]} : (tensor<128x10xf32>, tensor<*xi32>) -> 
(tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "", padding_map = [""], topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\02\02\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\02\02\22\04\03\02\01\00", "\08\01\1A\01\01\22\01\00"]} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir index 2c49c2060f1..fff1240a121 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir @@ -1,10 +1,10 @@ // RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tpu-sharding-identification | FileCheck %s --dump-input=fail -// Tests empty launch func. Empty input/output sharding configuration +// Tests empty cluster func. Empty input/output sharding configuration // attributes must be added. -// CHECK-LABEL: func @check_sharding_attrs_exists_for_empty_launch_func -func @check_sharding_attrs_exists_for_empty_launch_func() { - "tf_device.launch_func"() {device = "", func = @empty_func, step_marker_location = ""} : () -> () +// CHECK-LABEL: func @check_sharding_attrs_exists_for_empty_cluster_func +func @check_sharding_attrs_exists_for_empty_cluster_func() { + "tf_device.cluster_func"() {func = @empty_func, step_marker_location = ""} : () -> () // CHECK: input_sharding_configuration = [] // CHECK: output_sharding_configuration = [] return @@ -21,7 +21,7 @@ func @empty_func() { // gets default maximal(0) sharding configuration. // CHECK-LABEL: func @check_default_sharding_for_block_arg_inputs_outputs func @check_default_sharding_for_block_arg_inputs_outputs(%arg0: tensor<*xi32>) { - "tf_device.launch_func"(%arg0) {device = "", func = @func_without_sharding, step_marker_location = ""} : (tensor<*xi32>) -> () + "tf_device.cluster_func"(%arg0) {func = @func_without_sharding, step_marker_location = ""} : (tensor<*xi32>) -> () // CHECK: input_sharding_configuration // CHECK-SAME: ["\08\01\1A\01\01\22\01\00"] // CHECK: output_sharding_configuration @@ -42,7 +42,7 @@ func @func_without_sharding(%arg0: tensor<*xi32>) -> tensor<*xi32> { // default maximal(0) sharding configuration. // CHECK-LABEL: func @check_default_sharding_for_inputs_outputs func @check_default_sharding_for_inputs_outputs(%arg0: tensor<*xi32>) { - "tf_device.launch_func"(%arg0) {device = "", func = @func_without_sharding, step_marker_location = ""} : (tensor<*xi32>) -> () + "tf_device.cluster_func"(%arg0) {func = @func_without_sharding, step_marker_location = ""} : (tensor<*xi32>) -> () // CHECK: input_sharding_configuration // CHECK-SAME: ["\08\01\1A\01\01\22\01\00"] // CHECK: output_sharding_configuration @@ -63,7 +63,7 @@ func @func_without_sharding(%arg0: tensor<*xi32>) -> tensor<*xi32> { // Tests with a input arg connected to XlaSharding op. 
// CHECK-LABEL: func @check_sharding_for_input_correctly_identified func @check_sharding_for_input_correctly_identified(%arg0: tensor<*xi32>) { - "tf_device.launch_func"(%arg0) {device = "", func = @inputs_with_sharding_func, step_marker_location = ""} : (tensor<*xi32>) -> () + "tf_device.cluster_func"(%arg0) {func = @inputs_with_sharding_func, step_marker_location = ""} : (tensor<*xi32>) -> () // CHECK: input_sharding_configuration // CHECK-SAME: ["\01\02\03"] // CHECK: output_sharding_configuration @@ -85,7 +85,7 @@ func @inputs_with_sharding_func(%arg0: tensor<*xi32>) -> tensor<*xi32> { // Tests with sharding is correctly parsed for multiple inputs/outputs. // CHECK-LABEL: func @check_sharding_for_multiple_inputs_outputs func @check_sharding_for_multiple_inputs_outputs(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) { - "tf_device.launch_func"(%arg0, %arg1) {device = "", func = @func_with_sharding, step_marker_location = ""} : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) + "tf_device.cluster_func"(%arg0, %arg1) {func = @func_with_sharding, step_marker_location = ""} : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) // CHECK: input_sharding_configuration // CHECK-SAME: ["\01\02\03", "\04\05\06"] // CHECK: output_sharding_configuration @@ -110,7 +110,7 @@ func @func_with_sharding(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<* // Tests with input sharding following an identity op. // CHECK-LABEL: func @check_sharding_after_identity func @check_sharding_after_identity(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) { - "tf_device.launch_func"(%arg0, %arg1) {device = "", func = @func_with_sharding_after_identity, step_marker_location = ""} : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) + "tf_device.cluster_func"(%arg0, %arg1) {func = @func_with_sharding_after_identity, step_marker_location = ""} : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) // CHECK: input_sharding_configuration // CHECK-SAME: ["\01\02\03", "\04\05\06"] // CHECK: output_sharding_configuration @@ -136,7 +136,7 @@ func @func_with_sharding_after_identity(%arg0: tensor<*xi32>, %arg1: tensor<*xi1 // Tests with input sharding following a ReadVariable op. 
// CHECK-LABEL: func @check_sharding_after_read_variable func @check_sharding_after_read_variable(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) { - "tf_device.launch_func"(%arg0, %arg1) {device = "", func = @func_with_sharding_after_read_variable, step_marker_location = ""} : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) + "tf_device.cluster_func"(%arg0, %arg1) {func = @func_with_sharding_after_read_variable, step_marker_location = ""} : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) // CHECK: input_sharding_configuration // CHECK-SAME: ["\01\02\03", "\04\05\06"] // CHECK: output_sharding_configuration @@ -164,7 +164,7 @@ func @func_with_sharding_after_read_variable(%arg0: tensor<*x!tf.resource, %arg1: tensor<*xi1>) { - "tf_device.launch_func"(%arg0, %arg1) {device = "", func = @func_with_sharding_after_cast, step_marker_location = ""} : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) + "tf_device.cluster_func"(%arg0, %arg1) {func = @func_with_sharding_after_cast, step_marker_location = ""} : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) // CHECK: input_sharding_configuration // CHECK-SAME: ["\01\02\03", "\04\05\06"] // CHECK: output_sharding_configuration @@ -191,7 +191,7 @@ func @func_with_sharding_after_cast(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) - // Tests that input sharding inside a functional op is parsed correctly. // CHECK-LABEL: func @check_sharding_inside_functional_op func @check_sharding_inside_functional_op(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) { - "tf_device.launch_func"(%arg0, %arg1) {device = "", func = @func_with_device_training_loop, step_marker_location = ""} : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) + "tf_device.cluster_func"(%arg0, %arg1) {func = @func_with_device_training_loop, step_marker_location = ""} : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) // CHECK: input_sharding_configuration // CHECK-SAME: ["\01\02\03", "\04\05\06"] // CHECK: output_sharding_configuration diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc b/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc index 01c30eabd35..fb3ecfde771 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc @@ -36,7 +36,7 @@ namespace { constexpr char kReplicationAttr[] = "tf_device.is_same_data_across_replicas"; constexpr char kMirroredVariableIndicesAttr[] = "_mirrored_variable_indices"; -// Analyzes the inputs to LaunchFuncOps in the module, and annotates their +// Analyzes the inputs to ClusterFuncOps in the module, and annotates their // invoked functions whether each input has the same data across replicas. 
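The annotation decision described above is simple bookkeeping: an operand of the outlined cluster function carries the same data across replicas either when it is captured from outside the `tf_device.replicate` region or when it is a replicate block argument whose index appears in `_mirrored_variable_indices`. The standalone C++ sketch below models only that decision with plain containers; it is not the MLIR pass itself, and all names and data in it are illustrative.

    // Minimal standalone model of the annotation decision (illustrative only,
    // not the MLIR pass). For each cluster_func operand we record the index of
    // the tf_device.replicate block argument it traces back to, or -1 when the
    // value is captured from outside the replicate region.
    #include <iostream>
    #include <set>
    #include <vector>

    int main() {
      std::vector<int> operand_to_replicate_arg = {0, 1, -1, 3};
      // Indices taken from the replicate's _mirrored_variable_indices attribute.
      std::set<int> mirrored_indices = {0, 3};

      for (int i = 0; i < (int)operand_to_replicate_arg.size(); ++i) {
        int arg = operand_to_replicate_arg[i];
        // Captured values are identical for every replica; per-replica block
        // arguments are identical only if they are listed as mirrored variables.
        bool same_across_replicas = arg < 0 || mirrored_indices.count(arg) > 0;
        std::cout << "callee argument " << i
                  << " tf_device.is_same_data_across_replicas="
                  << (same_across_replicas ? "true" : "false") << "\n";
      }
    }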
struct AnnotateParameterReplication : public PassWrapper(); + m.walk([&](tf_device::ClusterFuncOp cluster_func) { + auto replicate = cluster_func.getParentOfType(); if (!replicate) return; auto mirrored_variable_indices_attr = replicate.getAttrOfType(kMirroredVariableIndicesAttr); @@ -69,8 +69,8 @@ void AnnotateParameterReplication::runOnOperation() { mirrored_index.cast().getInt()); } } - auto func = llvm::cast(m.lookupSymbol(launch_func.func())); - for (auto entry : llvm::enumerate(launch_func.getOperands())) { + auto func = llvm::cast(m.lookupSymbol(cluster_func.func())); + for (auto entry : llvm::enumerate(cluster_func.getOperands())) { auto operand = SkipIdentityAndReadVariable(entry.value()); auto block_arg = operand.dyn_cast(); if (block_arg && block_arg.getOwner() == &replicate.GetBody()) { @@ -98,7 +98,7 @@ CreateAnnotateParameterReplicationPass() { static PassRegistration pass( "tf-annotate-parameter-replication", - "Annotate whether a LaunchFuncOp's parameters have the same data across " + "Annotate whether a ClusterFuncOp's parameters have the same data across " "replicas."); } // namespace TFDevice diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc index fc1622b93e9..a01769bc395 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc @@ -30,9 +30,10 @@ namespace { void EnableLogging(PassManager *pm) { // Print the whole module after each pass, which requires disabling // multi-threading as well. - pm->disableMultithreading(); + pm->getContext()->disableMultithreading(); pm->enableIRPrinting(std::make_unique( /*print_module_scope=*/true)); + pm->enableTiming(std::make_unique()); } } // namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc index aa4c071abdf..886bd5b5b65 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This pass outlines regions of `tf_device.launch` into functions and replaces -// `tf_device.launch` with equivalent `tf_device.launch_func` operations. +// This pass outlines regions of `tf_device.cluster` into functions and replaces +// `tf_device.cluster` with equivalent `tf_device.cluster_func` operations. 
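As a rough illustration of the outlining step (a sketch over a toy IR, not the MLIR implementation), the program below collects the live-ins of a "cluster" region, i.e. values used inside the region but defined outside it, emits a standalone function whose parameters are those live-ins, and leaves behind a single call-like op that refers to the function through a `func` attribute, mirroring how `tf_device.cluster_func` refers to the outlined `FuncOp`. The `_func` name echoes the placeholder name the pass now uses; everything else here is made up for the example.

    // Toy-IR sketch of cluster outlining (illustrative only). Values are plain
    // strings; an Op has a result name, an op name, and operand names. Live-ins
    // of the region are the operands not produced by any op inside the region.
    #include <cstddef>
    #include <iostream>
    #include <set>
    #include <string>
    #include <vector>

    struct Op {
      std::string result;
      std::string name;
      std::vector<std::string> operands;
    };

    int main() {
      // Body of a hypothetical tf_device.cluster; %a and %b are defined above it.
      std::vector<Op> cluster_body = {
          {"%0", "tf.AddV2", {"%a", "%b"}},
          {"%1", "tf.Mul", {"%0", "%a"}},
      };

      // Collect live-ins: operands used in the region but not defined in it.
      std::set<std::string> defined, live_ins;
      for (const Op& op : cluster_body) defined.insert(op.result);
      std::vector<std::string> live_in_order;
      for (const Op& op : cluster_body)
        for (const std::string& operand : op.operands)
          if (!defined.count(operand) && live_ins.insert(operand).second)
            live_in_order.push_back(operand);

      // "Outline" the body into a function whose arguments are the live-ins.
      std::cout << "func @_func(";
      for (std::size_t i = 0; i < live_in_order.size(); ++i)
        std::cout << (i ? ", " : "") << live_in_order[i];
      std::cout << ") {\n";
      for (const Op& op : cluster_body) {
        std::cout << "  " << op.result << " = \"" << op.name << "\"(";
        for (std::size_t i = 0; i < op.operands.size(); ++i)
          std::cout << (i ? ", " : "") << op.operands[i];
        std::cout << ")\n";
      }
      std::cout << "  return %1\n}\n";

      // The cluster is replaced by a call-like op carrying a func attribute.
      std::cout << "%r = \"tf_device.cluster_func\"(";
      for (std::size_t i = 0; i < live_in_order.size(); ++i)
        std::cout << (i ? ", " : "") << live_in_order[i];
      std::cout << ") {func = @_func}\n";
    }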
#include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project @@ -35,7 +35,6 @@ namespace TFDevice { namespace { -constexpr char kDeviceAttr[] = "device"; constexpr char kFuncAttr[] = "func"; struct ClusterOutliningPass @@ -43,28 +42,29 @@ struct ClusterOutliningPass void runOnOperation() override; }; -void ReplaceLaunchReturnWithReturn(tf_device::ReturnOp launch_return_op, - OpBuilder* builder) { - builder->create(launch_return_op.getLoc(), - launch_return_op.getOperands()); - launch_return_op.erase(); +void ReplaceClusterReturnWithReturn(tf_device::ReturnOp cluster_return_op, + OpBuilder* builder) { + builder->create(cluster_return_op.getLoc(), + cluster_return_op.getOperands()); + cluster_return_op.erase(); } -// Builds a function that outlines region attached to launch_op and inserts +// Builds a function that outlines region attached to cluster_op and inserts // built function into given module. -FuncOp BuildFunction(StringRef device, llvm::ArrayRef live_ins, - tf_device::LaunchOp launch_op, SymbolTable* symbol_table, +FuncOp BuildFunction(llvm::ArrayRef live_ins, + tf_device::ClusterOp cluster_op, SymbolTable* symbol_table, OpBuilder* builder) { llvm::SmallVector operand_types; operand_types.reserve(live_ins.size()); for (Value v : live_ins) operand_types.emplace_back(v.getType()); - auto func_type = FunctionType::get(operand_types, launch_op.getResultTypes(), + auto func_type = FunctionType::get(operand_types, cluster_op.getResultTypes(), builder->getContext()); - std::string func_name_prefix = Twine(device, "_func").str(); + // TODO(lyandy): Define better name for outlined function. Potentially some + // name can be added during cluster formation. FuncOp outlined_func = - FuncOp::create(launch_op.getLoc(), func_name_prefix, func_type); + FuncOp::create(cluster_op.getLoc(), "_func", func_type); // This function is not externally visible and marking it private would allow // symbol-dce pass to remove it when it is not referenced anymore. @@ -73,64 +73,59 @@ FuncOp BuildFunction(StringRef device, llvm::ArrayRef live_ins, // Create function body. Block* outlined_func_block = outlined_func.addEntryBlock(); - // Replace uses of live-in values within launch_op region with function + // Replace uses of live-in values within cluster_op region with function // arguments. - Region& launch_op_region = launch_op.body(); - for (const auto& p : - llvm::zip(live_ins, outlined_func_block->getArguments())) { + Region& cluster_op_region = cluster_op.body(); + for (auto p : llvm::zip(live_ins, outlined_func_block->getArguments())) { replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), - launch_op_region); + cluster_op_region); } - // Move all instructions in launch_op into outlined_function's only block. - auto& launch_op_body = launch_op_region.front().getOperations(); + // Move all instructions in cluster_op into outlined_function's only block. + auto& cluster_op_body = cluster_op.GetBody().getOperations(); outlined_func_block->getOperations().splice( - outlined_func_block->end(), launch_op_body, launch_op_body.begin(), - launch_op_body.end()); + outlined_func_block->end(), cluster_op_body, cluster_op_body.begin(), + cluster_op_body.end()); - // Replace `tf_device.launch_return` terminator with `std.return` in function + // Replace `tf_device.return` terminator with `std.return` in function // body. 
- auto launch_return_op = + auto cluster_return_op = cast(outlined_func_block->getTerminator()); - builder->setInsertionPoint(launch_return_op); - ReplaceLaunchReturnWithReturn(launch_return_op, builder); + builder->setInsertionPoint(cluster_return_op); + ReplaceClusterReturnWithReturn(cluster_return_op, builder); symbol_table->insert(outlined_func); return outlined_func; } -// Outlines body of `tf_device.launch` into a function and create a -// `tf_device.launch_func` to invoke that function. `tf_device.launch` is +// Outlines body of `tf_device.cluster` into a function and create a +// `tf_device.cluster_func` to invoke that function. `tf_device.cluster` is // removed afterwards.` -void OutlineLaunch(tf_device::LaunchOp launch_op, SymbolTable* symbol_table, - OpBuilder* builder) { +void OutlineCluster(tf_device::ClusterOp cluster_op, SymbolTable* symbol_table, + OpBuilder* builder) { llvm::SetVector live_ins; - getUsedValuesDefinedAbove(launch_op.body(), launch_op.body(), live_ins); + getUsedValuesDefinedAbove(cluster_op.body(), cluster_op.body(), live_ins); - StringRef device = - launch_op.getAttrOfType(kDeviceAttr).getValue(); + FuncOp outlined_func = + BuildFunction(live_ins.getArrayRef(), cluster_op, symbol_table, builder); + cluster_op.setAttr(builder->getIdentifier(kFuncAttr), + builder->getSymbolRefAttr(outlined_func.getName())); - FuncOp outlined_func = BuildFunction(device, live_ins.getArrayRef(), - launch_op, symbol_table, builder); - launch_op.setAttr(builder->getIdentifier(kFuncAttr), - builder->getSymbolRefAttr(outlined_func.getName())); + builder->setInsertionPoint(cluster_op); + auto cluster_func_op = builder->create( + cluster_op.getLoc(), outlined_func.getType().getResults(), + live_ins.getArrayRef(), cluster_op.getAttrs()); - builder->setInsertionPoint(launch_op); - tf_device::LaunchFuncOp launch_func_op = - builder->create( - launch_op.getLoc(), outlined_func.getType().getResults(), - live_ins.getArrayRef(), launch_op.getAttrs()); - - launch_op.replaceAllUsesWith(launch_func_op); - launch_op.erase(); + cluster_op.replaceAllUsesWith(cluster_func_op); + cluster_op.erase(); } void ClusterOutliningPass::runOnOperation() { - ModuleOp m = getOperation(); - SymbolTable symbol_table(m); - OpBuilder builder(m.getContext()); - m.walk([&](tf_device::LaunchOp launch) { - OutlineLaunch(launch, &symbol_table, &builder); + ModuleOp module = getOperation(); + SymbolTable symbol_table(module); + OpBuilder builder(module.getContext()); + module.walk([&](tf_device::ClusterOp cluster) { + OutlineCluster(cluster, &symbol_table, &builder); }); } @@ -142,7 +137,7 @@ std::unique_ptr> CreateClusterOutliningPass() { static PassRegistration pass( "tf-device-cluster-outlining", - "Outline regions of tf_device.launch operations."); + "Outline regions of tf_device.cluster operations."); } // namespace TFDevice } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 92fa4e74a68..81d0259d2d6 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -91,6 +91,10 @@ std::unique_ptr> CreateResourceDeviceInferencePass(); // of their aliasing output arguments. std::unique_ptr> CreatePromoteResourcesToArgsPass(); +// Creates a pass that promotes tf.VarHandleOp to resource arguments for all +// functions. +std::unique_ptr> CreatePromoteVarHandlesToArgsPass(); + // Marks function visibility using tf.entry_function specification. 
That is, // functions with tf.entry_function attributes are marked with public // visibility while the other functions are marked with private visibility. @@ -256,6 +260,11 @@ std::unique_ptr> CreateTPUMergeVariablesWithExecutePass(); // run-time according to compilation result. std::unique_ptr> CreateTPUVariableReformattingPass(); +// Creates a pass that extracts outside compilation (CPU ops inside TPU cluster) +// at head/tail of TPU cluster to run before/after TPU computation. +std::unique_ptr> +CreateTPUExtractHeadTailOutsideCompilationPass(); + // Creates a pass that extract outside compilation (CPU ops inside TPU cluster) // ops to a separate parallel_execute region to run on CPU. std::unique_ptr> CreateTPUExtractOutsideCompilationPass(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc index fa4fe461317..cece23b4750 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc @@ -47,11 +47,14 @@ limitations under the License. // . Dead functions have already been removed, as resource arguments in dead // functions can cause the pass to fail. +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/PointerUnion.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project @@ -73,114 +76,189 @@ constexpr char kResourceFunctionMsg[] = "expects function level resource argument"; constexpr char kInvalidResourceMsg[] = "expects resource to be a VarHandleOp or function argument"; +constexpr char kResourceNameArgAttr[] = "tf.resource_name"; -// Records the input argument index and the current live value for a resource -// variable. -// -// . If the input argument already exists or has been added, input_index is the -// index of the function, and live_value_or_type tracks the live value of the -// resource. -// -// . If the input argument has not been added in the pass, input_index is -// kInputUnassigned, live_value_or_type represents the type of the resource. -// (a) If this resource is read, add a new argument whose type is obtained -// from live_value_or_type, and input_index and live_value_or_type will be -// updated to reference the new argument. -// (b) If this resource is written, live_value_or_type will track the new -// value of the resource. input_index will remain to be kInputUnassigned. +// Checks if a function has only one block. +mlir::LogicalResult CheckSingleBlockFunction(FuncOp function) { + if (!hasSingleElement(function.getBlocks())) + return function.emitError() + << "expects function '" << function.getName() + << "' to have 1 block, got " << function.getBlocks().size(); + + return success(); +} + +// Collects names of users of a resource that are not `tf.ReadVariableOp` and +// not `tf.AssignVariableOp`. +llvm::SmallSet GetCompositeResourceUserNames( + Value resource) { + // SmallSet will use a vector when there is only one element and use std::set + // when there are more than one elements. This ensures that the operations in + // the error message are ordered. 
+ llvm::SmallSet composite_users; + for (Operation* user : resource.getUsers()) + if (!llvm::isa(user) && + !llvm::isa(user)) + composite_users.insert(user->getName().getStringRef()); + + return composite_users; +} + +// Checks if `tf.VarHandleOp` has a valid resource subtype and its users are of +// `tf.ReadVariableOp` and `tf.AssignVariableOp` only. +mlir::LogicalResult ValidateVarHandle(TF::VarHandleOp var_handle_op) { + auto resource_type = + getElementTypeOrSelf(var_handle_op.getType()).cast(); + if (resource_type.getSubtypes().size() != 1) + return var_handle_op.emitOpError() + << "expects resource type to have one subtype, got " + << resource_type; + + auto composite_ops = GetCompositeResourceUserNames(var_handle_op); + if (!composite_ops.empty()) + return var_handle_op.emitOpError() + << "expects users to be 'tf.ReadVariableOp' or " + "'tf.AssignVariableOp', got [" + << llvm::join(composite_ops.begin(), composite_ops.end(), ", ") + << "]"; + + return success(); +} + +// Checks if resource argument has a valid resource subtype and its users are of +// `tf.ReadVariableOp` and `tf.AssignVariableOp` only. +mlir::LogicalResult ValidateResourceArgument(FuncOp function, + BlockArgument resource_arg, + TF::ResourceType resource_type) { + if (resource_type.getSubtypes().size() != 1) + return function.emitError() + << "expects resource type of argument " + << resource_arg.getArgNumber() << " to have one subtype, got " + << resource_type; + + auto composite_ops = GetCompositeResourceUserNames(resource_arg); + if (!composite_ops.empty()) + return function.emitError() + << "expects users of resource argument " + << resource_arg.getArgNumber() + << " to be 'tf.ReadVariableOp' or 'tf.AssignVariableOp', got [" + << llvm::join(composite_ops.begin(), composite_ops.end(), ", ") + << "]"; + + return success(); +} + +// Adds resource arguments for every unique (name) variable handle. Associated +// `tf.VarHandleOp` are removed from the function. Variable shared names are +// returned in `var_handle_shared_names` based on the ordering of added resource +// arguments. +mlir::LogicalResult PromoteVarHandlesToArguments( + FuncOp function, bool add_validation, + llvm::SmallVectorImpl* var_handle_shared_names) { + Block& block = function.front(); + auto func_type = function.getType(); + + auto func_arg_types = llvm::to_vector<4>(func_type.getInputs()); + llvm::SmallDenseMap var_arg_index_by_name; + for (auto var_handle_op : + llvm::make_early_inc_range(block.getOps())) { + if (add_validation && failed(ValidateVarHandle(var_handle_op))) + return failure(); + + llvm::StringRef name = var_handle_op.shared_nameAttr().getValue(); + auto it = var_arg_index_by_name.insert({name, func_arg_types.size()}); + if (it.second) { + var_handle_shared_names->emplace_back(name); + auto resource_type = var_handle_op.resource().getType(); + func_arg_types.push_back(resource_type); + var_handle_op.resource().replaceAllUsesWith( + block.addArgument(resource_type)); + } else { + var_handle_op.resource().replaceAllUsesWith( + block.getArgument(it.first->getSecond())); + } + var_handle_op.erase(); + } + + if (!var_handle_shared_names->empty()) + function.setType(FunctionType::get(func_arg_types, func_type.getResults(), + function.getContext())); + + return success(); +} + +// Records the current live value for a resource variable and whether a read or +// write on the variable occurred. 
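The rewrite this bookkeeping supports boils down to store-load forwarding over a single block: each resource starts with its promoted argument as the live value, every read is replaced by the current live value, and every write simply updates the live value and marks the resource as written; only written resources get their final value appended to the function's return. The sketch below models that walk over a toy event list; the `Event` type and all values are illustrative, not part of the pass.

    // Toy model of the live-value walk used when promoting resources to
    // arguments (illustrative only, not the MLIR pass). Events are the reads
    // and writes of a single resource in block order.
    #include <iostream>
    #include <string>
    #include <vector>

    struct Event {
      std::string kind;   // "read" or "write"
      std::string value;  // value written (for writes)
    };

    int main() {
      // Initially the promoted argument itself is the live value.
      std::string live_value = "%arg0";
      bool read = false, write = false;

      std::vector<Event> events = {
          {"read", ""}, {"write", "%new0"}, {"read", ""}, {"write", "%new1"}};

      for (const Event& e : events) {
        if (e.kind == "read") {
          // A tf.ReadVariableOp is replaced by the current live value.
          std::cout << "read forwarded to " << live_value << "\n";
          read = true;
        } else {
          // A tf.AssignVariableOp only updates the live value.
          live_value = e.value;
          write = true;
        }
      }

      // A written resource has its final live value appended to the return; one
      // that was also read (or was an original resource argument) is recorded as
      // an input/output alias, while a write-only tf.VarHandleOp resource only
      // receives a tf.resource_name result attribute.
      if (write) std::cout << "return also yields " << live_value << "\n";
    }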
struct ResourceInfo { - static constexpr int64_t kInputUnassigned = -1; - int64_t input_index; - llvm::PointerUnion live_value_or_type; + Value live_value = nullptr; + bool read = false; + bool write = false; }; -using ArgOrName = llvm::PointerUnion; -using ResourceMap = llvm::SmallDenseMap; - -LogicalResult PromoteResourcesToArguments(FuncOp function) { +LogicalResult PromoteResourcesToArguments( + FuncOp function, llvm::ArrayRef var_handle_shared_names) { Block& block = function.front(); auto return_op = llvm::dyn_cast_or_null(block.getTerminator()); if (!return_op) - return function.emitError( - "expects 'main' function to have a MLIR ReturnOp"); + return function.emitError() << "expects function '" << function.getName() + << "' to have a MLIR ReturnOp"; - ResourceMap resource_map; + llvm::SmallVector resources(function.getNumArguments()); auto argument_types = llvm::to_vector<4>(function.getType().getInputs()); + bool has_resources = false; + auto add_resource_argument = [&](BlockArgument arg, + TF::ResourceType resource_type) { + Type arg_type = resource_type.getSubtypes().front(); + arg.setType(arg_type); + resources[arg.getArgNumber()].live_value = arg; + argument_types[arg.getArgNumber()] = arg_type; + has_resources = true; + }; - // Loop through the resource arguments in the function and store a mapping - // from that argument to its index and itself as the current live value. - for (BlockArgument& func_arg : function.getArguments()) { + // Loop through the non `tf.VarHandleOp` resource arguments in the function, + // validate its uses and subtype, and store a mapping from that argument to + // itself as the current live value. + auto func_args = function.getArguments().take_front( + function.getNumArguments() - var_handle_shared_names.size()); + for (BlockArgument& func_arg : func_args) { auto resource_type = getElementTypeOrSelf(func_arg.getType()).dyn_cast(); if (!resource_type) continue; - if (resource_type.getSubtypes().size() != 1) - return function.emitError() - << "expects resource type of argument " << func_arg.getArgNumber() - << " to have one subtype, got " << resource_type; + if (failed(ValidateResourceArgument(function, func_arg, resource_type))) + return failure(); - for (auto* user : func_arg.getUsers()) - if (!llvm::isa(user) && - !llvm::isa(user)) - return function.emitError() - << "expects users of resource argument " - << func_arg.getArgNumber() - << " to be 'tf.ReadVariableOp' or 'tf.AssignVariableOp'"; - - Type arg_type = resource_type.getSubtypes().front(); - func_arg.setType(arg_type); - resource_map[func_arg] = {func_arg.getArgNumber(), func_arg}; - argument_types[func_arg.getArgNumber()] = arg_type; + add_resource_argument(func_arg, resource_type); } - // Loop through the VarHandleOp in the function. When the first VarHandleOp - // for a resource variable is encountered, add an entry to the resource_map to - // record the information. Do not add a new function argument yet. - for (auto var_handle_op : block.getOps()) { - if (resource_map.count(var_handle_op.shared_nameAttr())) continue; - + // Loop through `tf.VarHandleOp` resource arguments in the function and store + // a mapping from that argument to itself as the current live value. No + // validations are necessary here as these arguments were validated prior to + // being added. 
+ auto var_handle_args = + function.getArguments().take_back(var_handle_shared_names.size()); + for (BlockArgument& var_handle_arg : var_handle_args) { auto resource_type = - getElementTypeOrSelf(var_handle_op.getType()).cast(); - if (resource_type.getSubtypes().size() != 1) - return var_handle_op.emitOpError() - << "expects resource type to have one subtype, got " - << resource_type; - - resource_map[var_handle_op.shared_nameAttr()] = { - ResourceInfo::kInputUnassigned, resource_type.getSubtypes().front()}; + getElementTypeOrSelf(var_handle_arg.getType()).cast(); + add_resource_argument(var_handle_arg, resource_type); } - if (resource_map.empty()) return success(); + if (!has_resources) return success(); // We initially assign the argument for a resource as the live value for the // resource. We then walk through the operations in the function in their // lexical order, to update the live value for the resource when we see a // store to the resource and replace reads of the resource with uses of its - // live value. For the reads, if the resource does not have a live value yet, - // we add a new argument and use it as the live value. + // live value. for (Operation& op : llvm::make_early_inc_range(block)) { if (auto read_op = llvm::dyn_cast(&op)) { if (auto func_arg = read_op.resource().dyn_cast()) { if (func_arg.getOwner() != &block) return read_op.emitOpError(kResourceFunctionMsg); - // resource_map[func_arg] is always a Value when func_arg is a - // BlockArgument. - read_op.value().replaceAllUsesWith( - resource_map[func_arg].live_value_or_type.get()); - } else if (auto var_handle_op = llvm::dyn_cast( - read_op.resource().getDefiningOp())) { - ResourceInfo& info = resource_map[var_handle_op.shared_nameAttr()]; - if (auto live_value = info.live_value_or_type.dyn_cast()) { - read_op.value().replaceAllUsesWith(live_value); - } else { - auto arg_type = info.live_value_or_type.get(); - BlockArgument arg = block.addArgument(arg_type); - info.input_index = argument_types.size(); - info.live_value_or_type = arg; - argument_types.push_back(arg_type); - read_op.value().replaceAllUsesWith(arg); - } + ResourceInfo& resource_info = resources[func_arg.getArgNumber()]; + resource_info.read = true; + read_op.value().replaceAllUsesWith(resource_info.live_value); } else { return read_op.emitOpError(kInvalidResourceMsg); } @@ -191,11 +269,9 @@ LogicalResult PromoteResourcesToArguments(FuncOp function) { if (func_arg.getOwner() != &block) return write_op.emitOpError(kResourceFunctionMsg); - resource_map[func_arg].live_value_or_type = write_op.value(); - } else if (auto var_handle_op = llvm::dyn_cast( - write_op.resource().getDefiningOp())) { - resource_map[var_handle_op.shared_nameAttr()].live_value_or_type = - write_op.value(); + ResourceInfo& resource_info = resources[func_arg.getArgNumber()]; + resource_info.write = true; + resource_info.live_value = write_op.value(); } else { return read_op.emitOpError(kInvalidResourceMsg); } @@ -206,67 +282,68 @@ LogicalResult PromoteResourcesToArguments(FuncOp function) { const int64_t num_results_before = function.getNumResults(); auto return_operands = llvm::to_vector<4>(return_op.getOperands()); - return_operands.reserve(num_results_before + resource_map.size()); auto result_types = llvm::to_vector<4>(return_op.getOperandTypes()); - result_types.reserve(num_results_before + resource_map.size()); - llvm::SmallVector, 4> output_only_resources; - output_only_resources.reserve(resource_map.size()); + llvm::SmallVector, 4> + output_only_resources; llvm::SmallVector, 
4> input_output_alias; - input_output_alias.reserve(resource_map.size()); - // Collect new return values and either (a) output-only resource attributes - // (if the resource is not promoted to an argument) or (b) mapping from - // resource input index to output alias (if the resource has been promoted to - // an argument). If the last live value is itself (argument), then that live - // value will not be returned as the resource is unmodified. - for (auto& resource : resource_map) { - int64_t input_index = resource.getSecond().input_index; - auto live_value = resource.getSecond().live_value_or_type.dyn_cast(); - if (input_index == ResourceInfo::kInputUnassigned) { - if (!live_value) continue; - - output_only_resources.push_back( - {return_operands.size(), resource.getFirst().dyn_cast()}); - } else { - // live_value is not nullptr because any input-assigned resource has a - // Value as live_value. - auto live_arg = live_value.dyn_cast(); - if (live_arg && live_arg.getOwner() == &block && - live_arg.getArgNumber() == input_index) - continue; - - input_output_alias.push_back({input_index, return_operands.size()}); - } - return_operands.push_back(live_value); - result_types.push_back(live_value.getType()); - } - - // Erase all VarHandleOp. - for (Operation& op : llvm::make_early_inc_range(function.front())) { - auto var_handle_op = llvm::dyn_cast(op); - if (!var_handle_op) continue; - if (!var_handle_op.use_empty()) { - // SmallSet will use a vector when there is only one element and use - // std::set when there are more than one elements. This ensures that - // the operations in the error message are ordered. - llvm::SmallSet unique_operations; - llvm::for_each( - var_handle_op.getOperation()->getUsers(), [&](Operation* user) { - unique_operations.insert(user->getName().getStringRef().str()); - }); - - return var_handle_op.emitOpError( - "expects no uses but used by operations: ") - << llvm::join(unique_operations.begin(), unique_operations.end(), - ", "); - } - - op.erase(); - } - - // Rewrite return if more results need to be returned by the function. + // Collect new return values for variable writes and either (a) output-only + // resource attributes (if the resource is not promoted to an argument) or (b) + // mapping from resource input index to output alias (if the resource has been + // promoted to an argument). Resource arguments that were originally + // `tf.VarHandleOp` but not read are collected and then removed. OpBuilder builder(return_op); - if (!output_only_resources.empty() || !input_output_alias.empty()) { + const int var_handles_start_idx = + function.getNumArguments() - var_handle_shared_names.size(); + int new_argument_index = 0; + llvm::SmallVector argument_indices_to_remove; + for (auto resource_and_index : llvm::enumerate(resources)) { + const auto& resource = resource_and_index.value(); + if (!resource.live_value) { + // Ignore non resource arguments. 
+ ++new_argument_index; + continue; + } + + const auto index = resource_and_index.index(); + const bool is_var_handle = index >= var_handles_start_idx; + if (resource.write) { + if (!is_var_handle || resource.read) { + input_output_alias.push_back( + {new_argument_index, return_operands.size()}); + } else if (is_var_handle) { + output_only_resources.push_back( + {return_operands.size(), + var_handle_shared_names[index - var_handles_start_idx]}); + } + return_operands.push_back(resource.live_value); + result_types.push_back(resource.live_value.getType()); + } + + if (is_var_handle && !resource.read) { + assert(block.getArgument(index).getUses().empty()); + argument_indices_to_remove.push_back(index); + } else { + if (is_var_handle) { + // Add resource_name attribute to VarHandleOp read. + function.setArgAttr( + new_argument_index, kResourceNameArgAttr, + builder.getStringAttr( + var_handle_shared_names[index - var_handles_start_idx])); + } + ++new_argument_index; + } + } + + // Remove unread var handle arguments. + for (int argument_index_to_remove : + llvm::reverse(argument_indices_to_remove)) { + block.eraseArgument(argument_index_to_remove); + argument_types.erase(argument_types.begin() + argument_index_to_remove); + } + + // Rewrite return if there are variable writes. + if (return_operands.size() > num_results_before) { builder.create(return_op.getLoc(), return_operands); return_op.erase(); } @@ -274,17 +351,10 @@ LogicalResult PromoteResourcesToArguments(FuncOp function) { // Update function argument and result types with new resource subtypes. function.setType(builder.getFunctionType(argument_types, result_types)); - // Add resource_name attribute to the input argument for the resources. - for (auto& resource : resource_map) { - if (auto attr = resource.getFirst().dyn_cast()) { - int64_t input_index = resource.getSecond().input_index; - if (input_index != ResourceInfo::kInputUnassigned) - function.setArgAttr(input_index, "tf.resource_name", attr); - } - } // Add resource_name attribute to the output for the resources. for (auto& resource : output_only_resources) - function.setResultAttr(resource.first, "tf.resource_name", resource.second); + function.setResultAttr(resource.first, kResourceNameArgAttr, + builder.getStringAttr(resource.second)); // Add aliasing_output attribute to the input argument for the resources that // are updated by the function. @@ -309,26 +379,60 @@ void PromoteResourcesToArgsPass::runOnOperation() { // This routine should only be called when control flow operations are still // represented with TF IfOp and WhileOp operations. In this case, there should // be only one basic blocks in the MLIR representation. 
- if (!hasSingleElement(main_func.getBlocks())) { - main_func.emitError() << "expects 'main' function to have 1 block, got " - << main_func.getBlocks().size(); - return signalPassFailure(); - } + if (failed(CheckSingleBlockFunction(main_func))) return signalPassFailure(); + llvm::SmallVector var_handle_shared_names; if (failed(ResourceLiftingForFunctionalControlFlow(main_func)) || - failed(PromoteResourcesToArguments(main_func))) + failed(PromoteVarHandlesToArguments(main_func, /*add_validation=*/true, + &var_handle_shared_names)) || + failed(PromoteResourcesToArguments(main_func, var_handle_shared_names))) return signalPassFailure(); } +class PromoteVarHandlesToArgsPass + : public PassWrapper> { + public: + void runOnOperation() override; +}; + +void PromoteVarHandlesToArgsPass::runOnOperation() { + ModuleOp module = getOperation(); + MLIRContext* context = module.getContext(); + for (auto function : module.getOps()) { + if (failed(CheckSingleBlockFunction(function))) return signalPassFailure(); + + llvm::SmallVector var_handle_shared_names; + PromoteVarHandlesToArguments(function, /*add_validation=*/false, + &var_handle_shared_names); + + // Add resource names for each `tf.VarHandleOp` that were promoted to + // resource arguments. + const int var_handle_args_offset = + function.getNumArguments() - var_handle_shared_names.size(); + for (auto var_name_and_index : llvm::enumerate(var_handle_shared_names)) + function.setArgAttr(var_name_and_index.index() + var_handle_args_offset, + kResourceNameArgAttr, + StringAttr::get(var_name_and_index.value(), context)); + } +} + } // namespace std::unique_ptr> CreatePromoteResourcesToArgsPass() { return std::make_unique(); } +std::unique_ptr> CreatePromoteVarHandlesToArgsPass() { + return std::make_unique(); +} + static PassRegistration pass( "tf-promote-resources-to-args", "Promote resources reads/writes to function inputs/outputs."); +static PassRegistration var_handle_pass( + "tf-promote-var-handles-to-args", + "Promote tf.VarHandleOps to function arguments."); + } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc index 30bc1a21075..2fd230005d0 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" @@ -107,10 +108,9 @@ llvm::SmallVector ExpandReplicateIntoReplicas( // Creates islands per replica from `tf_device.replicate` region and remap // replicate results with new island outputs. A single island is created to -// forward results from each replica island. Control dependencies of individual -// replicas are added to the single island if the single island does not emit -// a result from the respective replica. Devices are remapped from aliased -// devices to explicit devices, for `tf_device.launch` ops. +// forward control dependencies if there is a control dependency output from the +// replicate island. Devices are remapped from aliased devices to explicit +// devices, for `tf_device.launch` ops. 
// // For example, the following: // @@ -156,12 +156,9 @@ llvm::SmallVector ExpandReplicateIntoReplicas( // }) {device = "/DEVICE:3"} : () -> tensor // tf_executor.yield %a1, %b1 : tensor, tensor // } -// %6:2 = tf_executor.island(%3#2) { -// tf_executor.yield %0#0 : tensor -// } -LogicalResult CreateIslandsFromReplicate(const Dialect* tf_dialect, - tf_executor::IslandOp island_op, - tf_device::ReplicateOp replicate_op) { +void CreateIslandsFromReplicate(const Dialect* tf_dialect, + tf_executor::IslandOp island_op, + tf_device::ReplicateOp replicate_op) { OpBuilder builder(island_op); const int num_replicas = replicate_op.n().getLimitedValue(); @@ -181,45 +178,38 @@ LogicalResult CreateIslandsFromReplicate(const Dialect* tf_dialect, replica_result_and_idx.value(); // Remap replicate results to per replica result. - replicate_op.replaceAllUsesWith(replicas_outputs); + for (auto result : llvm::zip(island_op.outputs(), replicas_outputs)) + std::get<0>(result).replaceAllUsesWith(std::get<1>(result)); - // Collect per replica control dependency and add to island operand if replica - // island has no uses. - llvm::SmallVector island_operands; - for (auto& replica : replicas) - if (replica.use_empty()) island_operands.push_back(replica.control()); + // Add sink island to pin all replicas as a control dependency if there is a + // control dependency leading from the replicate originally. + if (!island_op.control().use_empty()) { + llvm::SmallVector island_operands; + for (auto& replica : replicas) island_operands.push_back(replica.control()); - // Create single island forwarding per replica result. - builder.setInsertionPoint(island_op); - auto island_sink = builder.create( - island_op.getLoc(), - llvm::to_vector<8>(island_op.GetYield().fetches().getTypes()), - tf_executor::ControlType::get(island_op.getContext()), island_operands); - island_sink.body().push_back(new Block); - - // Move replicate island YieldOp over to new single island. - island_op.GetYield().getOperation()->moveBefore( - &island_sink.GetBody(), island_sink.GetBody().begin()); - - // Remap island results. - island_op.replaceAllUsesWith(island_sink); + builder.setInsertionPoint(island_op); + auto island_sink = builder.create( + island_op.getLoc(), llvm::ArrayRef{}, + tf_executor::ControlType::get(island_op.getContext()), island_operands); + island_sink.body().push_back(new Block); + builder.setInsertionPointToEnd(&island_sink.GetBody()); + builder.create(island_op.getLoc(), + llvm::ArrayRef{}); + island_op.control().replaceAllUsesWith(island_sink.control()); + } island_op.erase(); - return success(); } // Finds islands with a single `tf_device.replicate` and create individual // islands per replica of the replicate. 
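The expansion sketched below is essentially a copy-per-replica of the wrapped body: replica i receives element i of each replicated operand, aliased device names are resolved to that replica's concrete device, and a control-only sink island is added only when the original island's control result had uses. This is a minimal standalone model with made-up operand and device names, not the MLIR pass.

    // Toy sketch of expanding a replicate into one island per replica
    // (illustrative only). Each replicated input supplies one value per
    // replica, and an aliased device is resolved to that replica's device.
    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    int main() {
      const int num_replicas = 2;
      // Replicated input: the block argument gets one value per replica.
      std::vector<std::string> replicated_input = {"%a0", "%a1"};
      // Aliased device resolved per replica, as in the devices attribute.
      std::map<std::string, std::vector<std::string>> devices = {
          {"CORE_0", {"/DEVICE:0", "/DEVICE:1"}}};

      for (int replica = 0; replica < num_replicas; ++replica) {
        std::cout << "%island_" << replica << " = tf_executor.island {\n"
                  << "  %r = \"tf.OpA\"(" << replicated_input[replica] << ")"
                  << " {device = \"" << devices["CORE_0"][replica] << "\"}\n"
                  << "  tf_executor.yield %r\n"
                  << "}\n";
      }
      // If the original island's control result had uses, a control-only sink
      // island depending on every replica island would be added as well.
    }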
-LogicalResult LowerSingleIslandReplicateToIslands( - const Dialect* tf_dialect, tf_executor::IslandOp island_op) { - if (!hasSingleElement(island_op.GetBody().without_terminator())) - return success(); +void LowerSingleIslandReplicateToIslands(const Dialect* tf_dialect, + tf_executor::IslandOp island_op) { + if (!island_op.WrapsSingleOp()) return; if (auto replicate_op = llvm::dyn_cast(&island_op.GetBody().front())) - return CreateIslandsFromReplicate(tf_dialect, island_op, replicate_op); - - return success(); + CreateIslandsFromReplicate(tf_dialect, island_op, replicate_op); } void ReplicateToIslandPass::runOnFunction() { @@ -229,13 +219,9 @@ void ReplicateToIslandPass::runOnFunction() { getFunction().emitError() << "'tf' dialect is not registered"; } - auto result = getFunction().walk([&](tf_executor::IslandOp island_op) { - if (failed(LowerSingleIslandReplicateToIslands(tf_dialect, island_op))) - return WalkResult::interrupt(); - return WalkResult::advance(); + getFunction().walk([&](tf_executor::IslandOp island_op) { + LowerSingleIslandReplicateToIslands(tf_dialect, island_op); }); - - if (result.wasInterrupted()) return signalPassFailure(); } } // anonymous namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc index faacaad4c98..611c4d2725a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc @@ -62,7 +62,7 @@ namespace { // TensorFlow resource variable and returns new value: // // %resource_handle = "tf.VarHandleOp"() -// %1 = "tf_device.launch"() ( { +// %1 = "tf_device.cluster"() ( { // %init_value = "tf.ReadVariableOp"(%resource_handle) // "tf.AssignAddVariableOp"(%resource_handle, %init_value) // %new_value = "tf.ReadVariableOp"(%resource_handle) @@ -73,7 +73,7 @@ namespace { // // %resource_handle = "tf.VarHandleOp"() // %init_value = "tf.ReadVariableOp"(%resource_handle) -// %1:2 = "tf_device.launch"() ( { +// %1:2 = "tf_device.cluster"() ( { // %new_value = "tf.AddV2"(%init_value, %init_value) // tf_device.return %new_value, %new_value // }) @@ -81,7 +81,7 @@ namespace { // // You can see that there are a few main changes applied: // 1) All the resource variable reads and writes are now outside of -// tf_device.launch op. +// tf_device.cluster op. // 2) Instead of taking resource handles as input, this device computation now // takes snapshotted values of that device. // 3) Some resource load operations are eliminated with store-load forwarding. @@ -89,13 +89,13 @@ namespace { // external resource store operations so that resources are still updated // after the computation. // -// If the launch body contains functional control flow, the pass first lifts the -// loads/stores in the body/cond/branch functions to the launch body, then +// If the cluster body contains functional control flow, the pass first lifts +// the loads/stores in the body/cond/branch functions to the cluster body, then // performs the above lifting. E.g., // -// func @launch_with_loop() -> () { +// func @cluster_with_loop() -> () { // %0 = "tf.VarHandleOp"() ... -// "tf_device.launch"() ( { +// "tf_device.cluster"() ( { // %1 = "tf.While"(%0) {body = @while_body, cond = @while_cond} // tf_device.return // }) @@ -113,10 +113,10 @@ namespace { // // will be be transformed to: // -// func @launch_with_loop() { +// func @cluster_with_loop() { // %0 = "tf.VarHandleOp"() ... 
// %1 = "tf.ReadVariableOp"(%0) -// %2 = "tf_device.launch"() ( { +// %2 = "tf_device.cluster"() ( { // %3 = "tf.While"(%1) {body = @while_body, cond = @while_cond} // tf_device.return %3 : tensor // }) : () -> tensor @@ -140,7 +140,7 @@ struct ResourceOpLiftingPass // such nodes to carry information. void RemoveIdentity(Block* block) { for (auto& op : llvm::make_early_inc_range(*block)) { - if (llvm::isa(&op) || llvm::isa(&op)) { + if (isa(&op) || isa(&op)) { op.replaceAllUsesWith(op.getOperands()); op.erase(); } @@ -241,7 +241,7 @@ bool AppendResourceStoreValueToReturn(Block* body) { // TODO(ycao): Prevent same value from being returned multiple times. // TODO(ycao): Do not return resource store value if it is defined outside - // of launch_op. + // of cluster. new_return_operands.push_back(assign_variable_op.value()); has_resource_store = true; } @@ -256,81 +256,78 @@ bool AppendResourceStoreValueToReturn(Block* body) { return true; } -// Moves resource store operations to after launch_op. This assumes load-store -// forwarding has been performed on this launch_op such that there is at most -// one resource store operation carrying its final value. -tf_device::LaunchOp SinkResourceStores(tf_device::LaunchOp launch_op, - OpBuilder* builder) { - // Update ReturnOp inside launch_op's body to output final values of updated +// Moves resource store operations to after cluster. This assumes load-store +// forwarding has been performed on this cluster such that there is at most one +// resource store operation carrying its final value. +tf_device::ClusterOp SinkResourceStores(tf_device::ClusterOp cluster, + OpBuilder* builder) { + // Update ReturnOp inside cluster's body to output final values of updated // external resources. - if (!AppendResourceStoreValueToReturn(&launch_op.GetBody())) return launch_op; + if (!AppendResourceStoreValueToReturn(&cluster.GetBody())) return cluster; - auto new_return_op = launch_op.GetBody().getTerminator(); - llvm::SmallVector new_launch_return_types( - new_return_op->getOperandTypes()); + auto new_return_op = cluster.GetBody().getTerminator(); + llvm::SmallVector new_return_types(new_return_op->getOperandTypes()); - builder->setInsertionPoint(launch_op); - auto new_launch_op = builder->create( - launch_op.getLoc(), new_launch_return_types, - /*operands=*/llvm::SmallVector(), launch_op.getAttrs()); - new_launch_op.body().takeBody(launch_op.body()); + builder->setInsertionPoint(cluster); + auto new_cluster = builder->create( + cluster.getLoc(), new_return_types, + /*operands=*/llvm::SmallVector(), cluster.getAttrs()); + new_cluster.body().takeBody(cluster.body()); - // Replace uses of old launch_op results with those of new_launch_op. - for (auto p : llvm::zip(launch_op.getResults(), new_launch_op.getResults())) { - std::get<0>(p).replaceAllUsesWith(std::get<1>(p)); - } + // Replace uses of old cluster results with those of new_cluster. + for (auto result : llvm::zip(cluster.getResults(), new_cluster.getResults())) + std::get<0>(result).replaceAllUsesWith(std::get<1>(result)); - // Create a mapping from operands of new_return_op operands to new_launch_op + // Create a mapping from operands of new_return_op operands to new_cluster // results. 
BlockAndValueMapping mapper; - for (auto p : - llvm::zip(new_return_op->getOperands(), new_launch_op.getResults())) { - mapper.map(std::get<0>(p), std::get<1>(p)); - } + for (auto operand_result : + llvm::zip(new_return_op->getOperands(), new_cluster.getResults())) + mapper.map(std::get<0>(operand_result), std::get<1>(operand_result)); // Clone all resource store ops and map their operands to values returned from - // new_launch_op. - for (Operation& op : llvm::make_early_inc_range(new_launch_op.GetBody())) { - if (dyn_cast(&op)) { + // new_cluster. + for (Operation& op : llvm::make_early_inc_range(new_cluster.GetBody())) { + if (isa(op)) { builder->clone(op, mapper); op.erase(); } } - launch_op.erase(); - return new_launch_op; + cluster.erase(); + return new_cluster; } -// Hoists resource variable loads and sinks stores from launch_op. -LogicalResult HoistResourceOpsFromLaunchOp(tf_device::LaunchOp launch_op) { - ModuleOp m = launch_op.getParentOfType(); - OpBuilder builder(m); +// Hoists resource variable loads and sinks stores from cluster. +LogicalResult HoistResourceOpsFromCluster(tf_device::ClusterOp cluster, + ModuleOp module) { + OpBuilder builder(module); // Remove identity nodes to avoid aliasing. - RemoveIdentity(&launch_op.GetBody()); + RemoveIdentity(&cluster.GetBody()); // Perform store-load forwarding. So that each resource is only loaded with // its initial value and is only stored with its final value. - ForwardStoreToLoad(&launch_op.GetBody()); + ForwardStoreToLoad(&cluster.GetBody()); - // Move loads of external resources, if any, to before launch_op. - // (Skipping resources created inside of launch_op.) + // Move loads of external resources, if any, to before cluster. + // (Skipping resources created inside of cluster.) HoistResourceLoads( - &launch_op.GetBody(), + &cluster.GetBody(), /*skip_load=*/ [&](TF::ReadVariableOp read) { - return read.resource().getParentRegion() == &launch_op.body(); + return read.resource().getParentRegion() == &cluster.body(); }, /*move_load=*/ [&](TF::ReadVariableOp read) { - read.getOperation()->moveBefore(launch_op); + read.getOperation()->moveBefore(cluster); }); - // Move stores of external resources, if any, to after launch_op. - auto new_launch_op = SinkResourceStores(launch_op, &builder); + // Move stores of external resources, if any, to after cluster. + auto new_cluster = SinkResourceStores(cluster, &builder); llvm::SetVector captured_values; - getUsedValuesDefinedAbove(new_launch_op.body(), new_launch_op.body(), + getUsedValuesDefinedAbove(new_cluster.body(), new_cluster.body(), captured_values); for (Value v : captured_values) { @@ -338,7 +335,7 @@ LogicalResult HoistResourceOpsFromLaunchOp(tf_device::LaunchOp launch_op) { if (!tensor_type) continue; if (!tensor_type.getElementType().isa()) continue; - return new_launch_op.emitOpError() + return new_cluster.emitOpError() << "has remaining resource inputs that can not be lifted"; } @@ -378,8 +375,7 @@ LogicalResult FindResourceArgUseInfo( info.data_type = assign.value().getType(); continue; } - if (llvm::isa(user) || - llvm::isa(user)) { + if (isa(user) || isa(user)) { // Stacks will be handled by a separate pass. 
do_not_touch = true; break; @@ -1034,7 +1030,7 @@ LogicalResult HoistForFunctionalControlFlow( for (auto local_var : local_vars) { if (llvm::all_of(local_var.resource().getUsers(), [](const Operation* user) { - return llvm::isa(user); + return isa(user); })) { for (auto user : local_var.resource().getUsers()) user->erase(); local_var.erase(); @@ -1043,18 +1039,18 @@ LogicalResult HoistForFunctionalControlFlow( return success(); } -// Lifts resource operation from tf_device.launch_func ops nested in `op` -// outside. Returns failure if there are remaining resource-type values that can -// not be lifted. +// Lifts resource operation from tf_device.cluster ops nested in `op` outside. +// Returns failure if there are remaining resource-type values that can not be +// lifted. void ResourceOpLiftingPass::runOnOperation() { llvm::SmallDenseMap lifted_partitioned_call_callees; - auto result = getOperation().walk([&](FuncOp func_op) { - return func_op.walk([&](tf_device::LaunchOp launch_op) { + ModuleOp module = getOperation(); + auto result = module.walk([&](FuncOp func_op) { + return func_op.walk([&](tf_device::ClusterOp cluster) { if (failed(HoistForFunctionalControlFlow( - &launch_op.GetBody(), getOperation(), - &lifted_partitioned_call_callees)) || - failed(HoistResourceOpsFromLaunchOp(launch_op))) { + &cluster.GetBody(), module, &lifted_partitioned_call_callees)) || + failed(HoistResourceOpsFromCluster(cluster, module))) { return WalkResult::interrupt(); } return WalkResult::advance(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index 38a1464ffcc..5a2cae38062 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -19,6 +19,8 @@ limitations under the License. #include #include +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/PointerUnion.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/iterator_range.h" @@ -26,6 +28,7 @@ limitations under the License. #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project @@ -56,12 +59,14 @@ limitations under the License. #define DEBUG_TYPE "tf-shape-inference" using ::tensorflow::int64; +using tensorflow::shape_inference::DimensionHandle; +using tensorflow::shape_inference::InferenceContext; +using tensorflow::shape_inference::ShapeHandle; namespace mlir { namespace TF { namespace { -Optional> InferShapeForFunctionReturnType( - FuncOp func) { +Optional> InferShapeForFunctionReturnType(FuncOp func) { // Find any return ops. SmallVector return_ops; for (Block& block : func) { @@ -121,19 +126,19 @@ bool IsSupportedNonTFOp(Operation* op) { // not a TF operation, as we can't guarantee that the new type will be OK. void AddCastBackForUnsupportedNonTFUses(Operation* op, Value result, Dialect* tf_dialect, Type old_type) { - OpBuilder builder(op); - builder.setInsertionPointAfter(op); // A tf.Cast operation is lazily created on the first uses that isn't a TF // operation. 
TF::CastOp cast_op; auto get_cast_op = [&]() { - if (!cast_op) - cast_op = - builder.create(op->getLoc(), old_type, result, - /*truncate=*/builder.getBoolAttr(false)); - return mlir::Value(cast_op); + if (!cast_op) { + OpBuilder b(op); + b.setInsertionPointAfter(op); + cast_op = b.create(op->getLoc(), old_type, result, + /*truncate=*/b.getBoolAttr(false)); + } + return Value(cast_op); }; - for (OpOperand& use : llvm::make_early_inc_range(result.getUses())) { + for (OpOperand& use : make_early_inc_range(result.getUses())) { if (use.getOwner()->getDialect() != tf_dialect && !IsSupportedNonTFOp(use.getOwner())) use.set(get_cast_op()); @@ -156,10 +161,22 @@ Optional GetShapeFromMlirType(Type t) { bool InferShapeForPassThroughOps(OperandRange pass_through_operands, Operation* op, Dialect* tf_dialect) { bool changed = false; - for (auto entry : llvm::zip(pass_through_operands, op->getResults())) { + for (auto entry : zip(pass_through_operands, op->getResults())) { Type operand_type = std::get<0>(entry).getType(); Value result = std::get<1>(entry); if (result.getType() == operand_type) continue; + // Pass through nodes may remove ref types, don't consider that as + // refinement. + // TODO(jpienaar): There could be refinement in addition to this, so + // refine this. + if (operand_type.cast() + .getElementType() + .isa() && + !result.getType() + .cast() + .getElementType() + .isa()) + continue; AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect, result.getType()); result.setType(operand_type); @@ -186,7 +203,7 @@ bool InferShapeForNonTFDialectOperation(Operation* op, Dialect* tf_dialect) { tf_dialect); } // TODO(b/155227679): Use OpInterface instead of hard-coding for TensorCastOp. - if (auto tensor_cast = dyn_cast(op)) { + if (auto tensor_cast = dyn_cast(op)) { return InferShapeForPassThroughOps( tensor_cast.getOperation()->getOperands(), op, tf_dialect); } @@ -236,9 +253,22 @@ GetSubtypes(Type type) { // match the i-th operand type). Returns true if anything is changed. bool PassThroughOperandTypes(OperandRange operands, ResultRange results) { bool changed = false; - for (auto entry : llvm::zip(operands, results)) { + for (auto entry : zip(operands, results)) { Type operand_type = std::get<0>(entry).getType(); - if (operand_type == std::get<1>(entry).getType()) continue; + Type result_type = std::get<1>(entry).getType(); + if (operand_type == result_type) continue; + // Pass through nodes may remove ref types, don't consider that as + // refinement. + // TODO(jpienaar): There could be refinement in addition to this, so + // refine this. + if (operand_type.cast() + .getElementType() + .isa() && + !result_type.cast() + .getElementType() + .isa()) + continue; + std::get<1>(entry).setType(operand_type); changed = true; } @@ -260,14 +290,13 @@ bool InferShapeForCall(Operation* op) { CallInterfaceCallable callable = call_op.getCallableForCallee(); SymbolRefAttr sym = callable.dyn_cast(); if (!sym) return false; - FuncOp func = - dyn_cast(SymbolTable::lookupNearestSymbolFrom(op, sym)); + FuncOp func = dyn_cast(SymbolTable::lookupNearestSymbolFrom(op, sym)); if (!func) return false; bool changed = false; // Map each of the results of the call to the returned type of the // function. - for (auto result : llvm::zip(op->getResults(), func.getType().getResults())) { + for (auto result : zip(op->getResults(), func.getType().getResults())) { if (std::get<0>(result).getType() == std::get<1>(result)) continue; // Skip already statically shaped results. 
if (!CanBeRefined(std::get<0>(result).getType())) continue; @@ -287,20 +316,293 @@ bool InferShapeForCall(Operation* op) { return changed; } -bool RefineTfConst(TF::ConstOp const_op) { - Type old_type = const_op.getType(); - if (const_op.valueAttr().getType() == old_type) return false; - const_op.getResult().setType(const_op.valueAttr().getType()); - AddCastBackForUnsupportedNonTFUses(const_op, const_op.getResult(), - const_op.getDialect(), old_type); - return true; +bool RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti, + Dialect* tf_dialect) { + Operation* op = infer_ti.getOperation(); + SmallVector inferred; + LogicalResult res = infer_ti.inferReturnTypes( + op->getContext(), op->getLoc(), op->getOperands(), + op->getAttrDictionary(), op->getRegions(), inferred); + if (failed(res)) { + op->emitOpError("failed to refine type as inference failed"); + return false; + } + + if (inferred == op->getResultTypes()) return false; + + // Map each of the results of the call to the returned type of the + // function. + bool changed = false; + for (auto result : zip(op->getResults(), inferred)) { + if (std::get<0>(result).getType() == std::get<1>(result)) continue; + + // Inserts a cast back to the original type if any user is not in the + // TF dialect. + AddCastBackForUnsupportedNonTFUses(op, std::get<0>(result), + op->getDialect(), std::get<1>(result)); + // Finally we inferred the shape and replace the type for this result. + std::get<0>(result).setType(std::get<1>(result)); + changed = true; + } + return changed; } } // namespace -bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, - int64_t graph_version) { - assert(tf_dialect == op->getDialect()); +// Combination of value producer and port of value produced (e.g., +// :, +// so for tf.Const -> tensor<10x20xf32>, [0,2,18] would point to a unique output +// scalar value). +struct ValuePort { + PointerUnion producer; + SmallVector port; + + bool operator==(const ValuePort& other) const { + return producer == other.producer && port == other.port; + } + + // Convert output value to ValuePort. + explicit ValuePort(Value v) { + OpResult opr = v.dyn_cast(); + if (opr) { + producer = opr.getOwner(); + port = {opr.getResultNumber()}; + } else { + producer = v.cast(); + port = {0}; + } + } + ValuePort(PointerUnion producer, + SmallVector port) + : producer(producer), port(port) {} + + raw_ostream& print(raw_ostream& os) const { + if (auto* op = producer.dyn_cast()) + os << "op " << op->getName(); + if (auto ba = producer.dyn_cast()) + os << "block_arg " << ba.getArgNumber(); + os << formatv(" [{0}]", llvm::make_range(port.begin(), port.end())); + return os; + } +}; + +struct ValuePortHasher { + std::size_t operator()(const ValuePort& other) const { + return hash_combine(llvm::hash_value(other.producer.getOpaqueValue()), + hash_value(ArrayRef(other.port))); + } +}; + +using ValuePortResultMap = + std::unordered_map; +using ComputedQueryFn = function_ref; +using ValueQueryFn = function_ref; +using ValuePortInputs = SmallVectorImpl; + +// TODO(jpienaar): ComputeInputsRequiredForOutput and ComputeOutputComponent are +// intended to be switched to op interfaces once more refined. +LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port, + ComputedQueryFn has_been_computed, + ValuePortInputs* inputs) { + auto op = value_port.producer.dyn_cast(); + auto& port = value_port.port; + if (!op) return failure(); + + // No inputs required for constants. 
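+  // (ComputeOutputComponent reads a constant's value straight from its value
+  // attribute, so nothing needs to be queued for it.)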
+ if (matchPattern(op, m_Constant())) return success(); + + // Note: this focusses only on the trivial pack op case and this could be + // generalized. + if (auto pack_op = dyn_cast(op)) { + if (pack_op.getType().cast().getRank() != 1) return failure(); + if (port.size() != 2) return failure(); + assert(port[0] == 0); + ValuePort req(pack_op.getOperand(port[1])); + if (!has_been_computed(req)) inputs->push_back(req); + return success(); + } + + return failure(); +} + +// Computes the output produced by ValuePort using the query function of +// existing computed values. +Attribute ComputeOutputComponent(const ValuePort& value_port, + ValueQueryFn values) { + LLVM_DEBUG(value_port.print(llvm::errs() << "\nComputing output for ")); + + auto op = value_port.producer.dyn_cast(); + if (!op) return nullptr; + auto& port = value_port.port; + + if (port.empty()) { + LLVM_DEBUG(llvm::dbgs() << "skipping, port outside spec of " << op << "\n"); + return nullptr; + } + + ElementsAttr attr; + if (matchPattern(op, m_Constant(&attr))) { + if (port.size() == 1 && port[0] == 0) return attr; + return nullptr; + } + + // Note: this focusses only on the trivial pack op case and this could be + // generalized. + if (auto pack_op = dyn_cast(op)) { + if (pack_op.getType().cast().getRank() != 1) return nullptr; + if (port.size() != 2 || port[0] != 0) return nullptr; + ValuePort op_port(op->getOperand(port[1])); + return values(op_port); + } + return nullptr; +} + +// Context used during ShapeInference. This class contains common information +// that is required by the individual shape inference helper functions (e.g., +// TF Graph version, constant values computed, etc.) +class ShapeInference { + public: + ShapeInference(int64_t graph_version, MLIRContext* context); + + LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port, + ValuePortInputs* inputs) { + return ::mlir::TF::ComputeInputsRequiredForOutput( + value_port, + [this](const ValuePort& port) { + return results_.find(port) != results_.end(); + }, + inputs); + } + + Attribute ComputeOutputComponent(const ValuePort& value_port) { + return ::mlir::TF::ComputeOutputComponent( + value_port, [this](const ValuePort& port) { return results_[port]; }); + } + + // Returns ShapeHandle if the op result could be computed as shape. + ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic); + + void RecordValue(const ValuePort& value_port, Attribute value) { + results_[value_port] = value; + } + + // Performs shape inference on the provided op and return true if the type of + // at least one result has been changed. + // A tf.Cast() is inserted for any uses that isn't in the TensorFlow dialect. + // `graph_version` indicates the current GraphDef compatibility versions + // (the versions field in graph.proto). + bool InferShapeForSingleOperation(Operation* op); + + // Infers shape on the provided region, including nested ones, iterate until + // fix point with a limit of max_iteration. Returns success if fix point is + // reached before max_iteration. + LogicalResult InferShapeUntilFixPoint(Region* region, + int64_t max_iteration = 10); + + // Updates input types and refine shapes inside body of functions that are + // attached to ControlFlow ops (If/While). These functions include Then/Else + // branches of IfOp and Cond/Body functions of WhileOp. These functions share + // following common properties: + // 1) They are never reused, ie. having a single use in module. 
+ // 2) Their input types match those of their parent ops (excluding inputs + // like predicate). + // Returns a boolean indicating whether any change has been applied. + LogicalResult RefineShapeForControlFlowFunc(FuncOp func, + ArrayRef input_types, + int64_t max_iteration); + + // Propagate the shapes to the functions named. + LogicalResult PropagateShapeToFunctions( + ModuleOp module, Operation::operand_type_range input_types, + ArrayRef func_names, int64_t max_iteration); + + // Shape propagation for call/control flow ops. + LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op, + int64_t max_iteration); + + private: + // Mapping between ValuePort (which corresponds to an OpResult or smaller, + // e.g., first element of OpResult produded) to an Attribute if the ValuePort + // corresponds to a constant value. + ValuePortResultMap results_; + int64_t graph_version_; + MLIRContext* context_; + Dialect* tf_dialect_; +}; + +ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context) + : graph_version_(graph_version) { + context_ = context; + tf_dialect_ = context->getRegisteredDialect(); +} + +ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result, + InferenceContext* ic) { + LLVM_DEBUG(result.print(llvm::dbgs() << "\nEvaluate partially ")); + auto rt = result.getType().dyn_cast(); + if (!rt || !rt.hasStaticShape() || rt.getRank() != 1) return {}; + int dim_size = rt.getDimSize(0); + + // Worklist to direct partial evaluation. + SmallVector worklist; + + // Simple evaluator that attempts to partially evaluate the input value even + // if unable to evaluate the complete output. Below follows a simple stack + // based evaluation where it queries what operands/part of operands need to + // be evaluated and attempting to partially evaluate those operands. It does + // so by pushing the operands that need to be required on to the worklist + // before enqueuing the operation requiering those values. + std::vector dims(dim_size, ic->UnknownDim()); + for (unsigned int i = 0, e = dims.size(); i != e; ++i) { + LLVM_DEBUG(llvm::dbgs() << "\nConsidering output dim " << i << "\n"); + + worklist.push_back( + ValuePort{result.getOwner(), {result.getResultNumber(), i}}); + while (!worklist.empty()) { + auto front = worklist.pop_back_val(); + LLVM_DEBUG(front.print(llvm::errs() << "\nWorklist front ")); + + SmallVector inputs; + auto res = ComputeInputsRequiredForOutput(front, &inputs); + if (failed(res)) { + // Abort if unable to find which required inputs need to be computed. + worklist.clear(); + break; + } + + if (!inputs.empty()) { + // Enqueue required computation followed by its required operands in + // stack. + worklist.push_back(std::move(front)); + for (auto& it : inputs) worklist.push_back(std::move(it)); + continue; + } + + auto ret = ComputeOutputComponent(front); + if (!ret) continue; + + RecordValue(front, ret); + LLVM_DEBUG(ret.print(llvm::dbgs() << "\ncomputed result = ")); + + // If worklist is empty, then this is the root query op. 
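+      // For example (illustrative), for
+      //   %size = "tf.Pack"(%c0, %c1) : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
+      // where %c0 and %c1 are scalar tf.Const ops, each per-dimension query
+      // bottoms out at one of the constants and its value is folded into
+      // `dims` below; anything that cannot be resolved stays ic->UnknownDim().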
+ if (worklist.empty()) { + LLVM_DEBUG(llvm::dbgs() << "[root node]\n"); + if (auto dea = ret.dyn_cast()) { + if (dea.getNumElements() != 1) { + LLVM_DEBUG(llvm::errs() << "Unexpected number of elements\n"); + return {}; + } + int64_t val = (*dea.getIntValues().begin()).getSExtValue(); + dims[i] = ic->MakeDim(val); + } + } + } + } + return ic->MakeShape(dims); +} + +bool ShapeInference::InferShapeForSingleOperation(Operation* op) { + assert(tf_dialect_ == op->getDialect()); // The shape function of these ops sometimes does not propagate subtypes // (handle shapes) for resource and variant types. We use a simple passthrough // to make sure they are preserved in the output. @@ -312,7 +614,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, // If no result for this op needs shape inference, we have a fast-path return. // But if the type is a resource/variant, we do not skip it because we might // not have the handle shapes. - if (llvm::none_of(op->getResultTypes(), CanBeRefined)) { + if (none_of(op->getResultTypes(), CanBeRefined)) { LLVM_DEBUG(llvm::dbgs() << "Skipping inference for statically shaped op '" << op->getName() << "'.\n"); return false; @@ -327,8 +629,8 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, // This is necessary to avoid reprocessing the tf.Cast that are inserted at // the end of this function. if (isa(op) && - llvm::all_of(op->getResult(0).getUsers(), [&](Operation* user) { - return user->getDialect() != tf_dialect; + all_of(op->getResult(0).getUsers(), [&](Operation* user) { + return user->getDialect() != tf_dialect_; })) { LLVM_DEBUG(llvm::dbgs() << "Skipping inference for tf.Cast with no TF " "dialect operation users '" @@ -408,9 +710,9 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, // Perform the shape inference using an InferenceContext with the input // shapes. This object is abstracting the information that the ShapeInference // function operates on. - tensorflow::shape_inference::InferenceContext c( - graph_version, *node_def, op_reg_data->op_def, input_shapes, - input_tensors, /*input_tensors_as_shapes=*/{}, handle_shapes_and_types); + InferenceContext c(graph_version_, *node_def, op_reg_data->op_def, + input_shapes, input_tensors, + /*input_tensors_as_shapes=*/{}, handle_shapes_and_types); auto status = c.Run(op_reg_data->shape_inference_fn); if (!status.ok()) { LLVM_DEBUG(llvm::dbgs() << "Shape inference error for '" << *op @@ -418,6 +720,43 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, return false; } + // Determine if, during shape computation, the shape functions attempted to + // query an input operand as shape where the input was not known/constant. + bool requires_inputs = + any_of(llvm::seq(0, c.num_inputs()), [&](int input) { + return c.requested_input_tensor_as_partial_shape(input) && + !input_tensors[input]; + }); + if (requires_inputs) { + std::vector input_tensors_as_shapes; + for (int input : llvm::seq(0, c.num_inputs())) { + if (c.requested_input_tensor_as_partial_shape(input) && + !input_tensors[input]) { + auto op_result = op->getOperand(input).dyn_cast(); + if (!op_result) continue; + // Resize on first valid shape computed. + input_tensors_as_shapes.resize(c.num_inputs()); + auto handle = ComputeOutputAsShape(op_result, &c); + LLVM_DEBUG(llvm::dbgs() << "Requested " << input << " as shape " + << (handle.Handle() ? 
"found" : "not found")); + if (handle.Handle()) input_tensors_as_shapes[input] = handle; + } + } + + // Attempt to compute the unknown operands as shapes. + // Note: in the case where no partial outputs could be computed, this would + // be empty. + if (!input_tensors_as_shapes.empty()) { + c.set_input_tensors_as_shapes(input_tensors_as_shapes); + auto status = c.Run(op_reg_data->shape_inference_fn); + if (!status.ok()) { + LLVM_DEBUG(llvm::dbgs() << "Shape inference error for '" << *op + << "': " << status.error_message() << "\n"); + return false; + } + } + } + assert(c.num_outputs() == op->getNumResults() && "inference context matches the MLIR number of results."); @@ -430,12 +769,11 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, if (!CanBeRefined(result.getType())) continue; auto shaped_type = result.getType().cast(); - tensorflow::shape_inference::ShapeHandle shape_handle = c.output(output); + ShapeHandle shape_handle = c.output(output); LLVM_DEBUG(llvm::dbgs() << "Inferred output " << output << " : " << c.DebugString(shape_handle) << "\n"); - auto get_tensor_type = - [&c](const tensorflow::shape_inference::ShapeHandle& sh, - Type element_type) -> TensorType { + auto get_tensor_type = [&c](const ShapeHandle& sh, + Type element_type) -> TensorType { if (!c.RankKnown(sh)) return UnrankedTensorType::get(element_type); // Convert the shape from TensorFlow (int64) to MLIR (int64_t). SmallVector shape; @@ -449,7 +787,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, new_element_type.isa()) { auto handle_shapes_types = c.output_handle_shapes_and_types(output); if (handle_shapes_types) { - llvm::SmallVector subtypes; + SmallVector subtypes; OpBuilder b(op); for (const auto& shape_n_type : *handle_shapes_types) { Type element_type; @@ -469,7 +807,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, if (result.getType() == new_type) continue; // Inserts a cast back to the original type if any user is not in the TF // dialect. - AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect, + AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect_, result.getType()); // Finally we inferred the shape and replace the type for this result. result.setType(new_type); @@ -481,23 +819,13 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, return changed; } -// Updates input types and refine shapes inside body of functions that are -// attached to ControlFlow ops (If/While). These functions include Then/Else -// branches of IfOp and Cond/Body functions of WhileOp. These functions share -// following common properties: -// 1) They are never reused, ie. having a single use in module. -// 2) Their input types match those of their parent ops (excluding inputs like -// predicate). -// Returns a boolean indicating whether any change has been applied. 
-LogicalResult RefineShapeForControlFlowFunc(FuncOp func, - llvm::ArrayRef input_types, - int64_t graph_version, - int64_t max_iteration) { +LogicalResult ShapeInference::RefineShapeForControlFlowFunc( + FuncOp func, ArrayRef input_types, int64_t max_iteration) { ModuleOp module = func.getParentOfType(); auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion()); int num_uses = std::distance(func_uses->begin(), func_uses->end()); if (num_uses != 1) { - func.emitWarning(llvm::formatv( + func.emitWarning(formatv( "expected control flow function {0} to have exactly 1 use, found {1}.", func.getName(), num_uses)); return failure(); @@ -511,8 +839,7 @@ LogicalResult RefineShapeForControlFlowFunc(FuncOp func, arg_and_idx.value().setType(input_types[arg_and_idx.index()]); } - auto res = - InferShapeUntilFixPoint(&func.getBody(), graph_version, max_iteration); + auto res = InferShapeUntilFixPoint(&func.getBody(), max_iteration); if (failed(res)) return res; auto new_return_types = InferShapeForFunctionReturnType(func); @@ -524,20 +851,18 @@ LogicalResult RefineShapeForControlFlowFunc(FuncOp func, return success(); } -LogicalResult PropagateShapeToFunctions( +LogicalResult ShapeInference::PropagateShapeToFunctions( ModuleOp module, Operation::operand_type_range input_types, - llvm::ArrayRef func_names, int64_t graph_version, - int64_t max_iteration) { - bool success = true; + ArrayRef func_names, int64_t max_iteration) { + bool all_succeeded = true; auto types = llvm::to_vector<4>(input_types); for (auto func_name : func_names) { FuncOp func = module.lookupSymbol(func_name); - if (failed(RefineShapeForControlFlowFunc(func, types, graph_version, - max_iteration))) { - success = false; - } + all_succeeded = + succeeded(RefineShapeForControlFlowFunc(func, types, max_iteration)) && + all_succeeded; } - return mlir::success(success); + return success(all_succeeded); } // If the callee has only one use, propagates any constant operand of call_op to @@ -557,7 +882,7 @@ void PropagateConstantToCallee(CallOpInterface call_op, // the constant inside the function. 
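  // For example, a tf.Const passed as a call operand is cloned into the callee
  // and the matching block argument's uses are redirected to the clone, so
  // shape functions inside the callee can fold against the constant value.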
for (auto arg : func.getArguments()) { auto operand = op->getOperand(arg.getArgNumber()).getDefiningOp(); - if (llvm::isa_and_nonnull(operand)) { + if (isa_and_nonnull(operand)) { arg.replaceAllUsesWith(builder.clone(*operand)->getResult(0)); } } @@ -576,33 +901,31 @@ void PropagateConstantFromCallee(CallOpInterface call_op, for (auto retval : llvm::enumerate(func.front().getTerminator()->getOperands())) { auto retval_op = retval.value().getDefiningOp(); - if (llvm::isa_and_nonnull(retval_op)) { + if (isa_and_nonnull(retval_op)) { op->getResult(retval.index()) .replaceAllUsesWith(builder.clone(*retval_op)->getResult(0)); } } } -LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op, - int64_t graph_version, - int64_t max_iteration) { +LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions( + Operation* op, int64_t max_iteration) { ModuleOp module = op->getParentOfType(); if (auto if_op = dyn_cast(op)) { return PropagateShapeToFunctions( - module, llvm::drop_begin(if_op.getOperandTypes(), 1), - {if_op.then_branch(), if_op.else_branch()}, graph_version, - max_iteration); + module, drop_begin(if_op.getOperandTypes(), 1), + {if_op.then_branch(), if_op.else_branch()}, max_iteration); } else if (auto while_op = dyn_cast(op)) { return PropagateShapeToFunctions(module, while_op.getOperandTypes(), {while_op.cond(), while_op.body()}, - graph_version, max_iteration); + max_iteration); } else if (auto call_op = dyn_cast(op)) { CallInterfaceCallable callable = call_op.getCallableForCallee(); if (SymbolRefAttr sym = callable.dyn_cast()) { PropagateConstantToCallee(call_op, sym, module); if (failed(PropagateShapeToFunctions( module, call_op.getArgOperands().getTypes(), - {sym.getRootReference()}, graph_version, max_iteration))) { + {sym.getRootReference()}, max_iteration))) { return failure(); } PropagateConstantFromCallee(call_op, sym, module); @@ -615,13 +938,10 @@ LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op, return success(); } -LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version, - int64_t max_iteration) { - MLIRContext* ctx = region->getContext(); - Dialect* tf_dialect = ctx->getRegisteredDialect(); - - // An operation folder that is used to attempt folding before inference. - OperationFolder folder(ctx); +LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region, + int64_t max_iteration) { + // An operation folder that is used to attempt folding before inference._ + OperationFolder folder(context_); bool changed = true; // TODO(aminim): we could have a more efficient traversal by guiding the @@ -633,30 +953,29 @@ LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version, LLVM_DEBUG(llvm::dbgs() << "Shape inference, iteration " << iteration << "\n"); region->walk([&](Operation* op) { - if (op->getDialect() != tf_dialect) { - changed |= InferShapeForNonTFDialectOperation(op, tf_dialect); - return; - } - - if (auto tf_const = dyn_cast(op)) { - changed |= RefineTfConst(tf_const); + if (auto infer_ti = dyn_cast(op)) { + changed |= RefineWithInferTypeOpInterface(infer_ti, tf_dialect_); // TODO(jpienaar): Debug why we can't just return here. We end up with // additional constant due to the propagation of constant into attached // function if we return already. } + if (op->getDialect() != tf_dialect_) { + changed |= InferShapeForNonTFDialectOperation(op, tf_dialect_); + return; + } + // Before attempting inference, just try to fold the operation. 
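      // Folding first keeps ops that reduce to constants (e.g. a tf.Shape of a
      // statically shaped tensor) from ever reaching the TensorFlow shape
      // functions below.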
if (succeeded(folder.tryToFold(op))) return; // Best-effort shape inference in attached functions. Do not return // failure even if it doesn't get to fixed point. - if (failed(PropagateShapeIntoAttachedFunctions(op, graph_version, - max_iteration))) { + if (failed(PropagateShapeIntoAttachedFunctions(op, max_iteration))) { op->emitWarning() << "unable to refine shape of attached function " "arguments and bodies"; } - changed |= InferShapeForSingleOperation(op, tf_dialect, graph_version); + changed |= InferShapeForSingleOperation(op); }); } @@ -671,31 +990,43 @@ LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version, LogicalResult InferShapeForFunction(FuncOp func, ArrayRef> arg_shapes, int64_t graph_version) { - mlir::FunctionType func_type = func.getType(); + ShapeInference context(graph_version, func.getContext()); + if (arg_shapes.empty()) { + if (failed(context.InferShapeUntilFixPoint(&func.getBody()))) + return failure(); + // TODO(b/156276510): Verify that it is always fine to refine a function's + // return type, as long as we do not change the argument shapes. + if (auto return_types = InferShapeForFunctionReturnType(func)) { + func.setType(FunctionType::get(func.getType().getInputs(), + return_types.getValue(), + func.getContext())); + } + + return success(); + } + FunctionType func_type = func.getType(); bool needs_refinement = false; - llvm::SmallVector new_arg_types; + SmallVector new_arg_types; new_arg_types.reserve(func_type.getNumInputs()); // Update argument types in-place using the provided arg_shapes. for (size_t i = 0; i < func_type.getNumInputs(); ++i) { ArrayRef shape = arg_shapes[i]; - mlir::Type element_type; - if (auto input_ty = - func_type.getInput(i).dyn_cast()) { + Type element_type; + if (auto input_ty = func_type.getInput(i).dyn_cast()) { if (!input_ty || input_ty.getShape().size() != shape.size()) { return failure(); } element_type = input_ty.getElementType(); } else { - auto unranked_input_ty = - func_type.getInput(i).dyn_cast(); + auto unranked_input_ty = func_type.getInput(i).dyn_cast(); if (!unranked_input_ty) { return failure(); } element_type = unranked_input_ty.getElementType(); } - auto new_arg_type = mlir::RankedTensorType::get(shape, element_type); + auto new_arg_type = RankedTensorType::get(shape, element_type); if (new_arg_type != func_type.getInput(i)) { // If the new type is more detailed, trigger shape inference. func.getArgument(i).setType(new_arg_type); @@ -708,28 +1039,17 @@ LogicalResult InferShapeForFunction(FuncOp func, return success(); } - mlir::LogicalResult result = - mlir::TF::InferShapeUntilFixPoint(&func.getBody(), graph_version); + LogicalResult result = context.InferShapeUntilFixPoint(&func.getBody()); if (failed(result)) { return failure(); } auto return_types = InferShapeForFunctionReturnType(func); - func.setType(mlir::FunctionType::get(new_arg_types, - return_types.hasValue() - ? return_types.getValue() - : func.getType().getResults(), - func.getContext())); - - return success(); -} - -LogicalResult InferShapeForFunctionType(FuncOp func) { - if (auto return_types = InferShapeForFunctionReturnType(func)) { - func.setType(mlir::FunctionType::get(func.getType().getInputs(), - return_types.getValue(), - func.getContext())); - } + func.setType(FunctionType::get(new_arg_types, + return_types.hasValue() + ? 
return_types.getValue() + : func.getType().getResults(), + func.getContext())); return success(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h index 0524ec678ed..e36d8d56d6d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h @@ -27,30 +27,13 @@ namespace mlir { namespace TF { -// Performs shape inference on the provided op and return true if the type of -// at least one result has been changed. -// A tf.Cast() is inserted for any uses that isn't in the TensorFlow dialect. -// `graph_version` indicates the current GraphDef compatibility versions -// (the versions field in graph.proto). -bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect, - int64_t graph_version); - -// Infers shape on the provided region, including nested ones, iterate until fix -// point with a limit of max_iteration. Returns success if fix point is reached -// before max_iteration. -LogicalResult InferShapeUntilFixPoint(Region* region, int64_t graph_version, - int64_t max_iteration = 10); - // Given a list of refined shapes matching the function arguments of func, runs // shape inference over the function to propagate this updated information. +// If arg_shapes are empty, then argument shapes will be left unchanged. LogicalResult InferShapeForFunction(FuncOp func, ArrayRef> arg_shapes, int64_t graph_version); -// Refines the return type of the given function by folding tf.Cast that -// precedes the return instruction. -LogicalResult InferShapeForFunctionType(FuncOp func); - } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc index 48e4e77ce0f..acdfc0eb039 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc @@ -58,10 +58,8 @@ struct ShapeInference } int64_t producer = producer_or.ValueOrDie(); for (auto func : module.getOps()) { - InferShapeUntilFixPoint(&func.getBody(), producer); - // TODO(yuanzx): Verify that it is always fine to refine a function's - // return type, as long as we do not change the argument shapes. - InferShapeForFunctionType(func); + if (failed(InferShapeForFunction(func, /*arg_shapes=*/{}, producer))) + return signalPassFailure(); } } }; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc b/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc index 0eafdea0964..e62df78ed11 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc @@ -41,15 +41,15 @@ using ::mlir::TF::ConstOp; class ExecutorConstantSinking : public mlir::PassWrapper { void runOnFunction() override { - getFunction().walk([](tf_device::LaunchOp launch) { - LLVM_DEBUG(llvm::dbgs() << "Visit " << *launch.getOperation() << "\n"); + getFunction().walk([](tf_device::ClusterOp cluster) { + LLVM_DEBUG(llvm::dbgs() << "Visit " << *cluster.getOperation() << "\n"); // For each launch op, we find the values used that come from a constant // defined above and sink these constants in the region body. // The sunk_constant map keeps a mapping from a ConstOp defined above to // a sunk clone of it. 
This allows for reusing a sunk constant with // multiple uses in the region. llvm::DenseMap sunk_constant; - Region &body = launch.body(); + Region &body = cluster.body(); visitUsedValuesDefinedAbove(body, [&](OpOperand *use) { Value constant = use->get(); auto const_op = dyn_cast_or_null(constant.getDefiningOp()); @@ -84,7 +84,7 @@ class ExecutorConstantSinking static mlir::PassRegistration pass( "tf-device-constant-sinking", - "Sink constants implicitly captured in a tf_device.launch region. This " + "Sink constants implicitly captured in a tf_device.cluster region. This " "reduces the number of arguments when outlining later."); } // anonymous namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.cc new file mode 100644 index 00000000000..786c4b74b34 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.cc @@ -0,0 +1,65 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.h" + +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +namespace mlir { +namespace TF { + +namespace { + +struct FuseParallelMapAndBatch : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(BatchDatasetV2Op op, + PatternRewriter &rewriter) const override { + auto batchInputDataset = op.input_dataset(); + + ParallelMapDatasetOp batchInputOp = dyn_cast_or_null( + batchInputDataset.getDefiningOp()); + if (!batchInputOp) return failure(); + + // The type of the `num_parallel_calls` argument in ParallelMapDataset + // and MapAndBatchDataset is different (int32 and int64 respectively) + auto num_parallel_calls_op = rewriter.create( + op.getLoc(), UnrankedTensorType::get(rewriter.getIntegerType(64)), + batchInputOp.num_parallel_calls(), rewriter.getBoolAttr(false)); + + auto fused_op = rewriter.create( + op.getLoc(), op.getType(), batchInputOp.input_dataset(), + batchInputOp.other_arguments(), op.batch_size(), + num_parallel_calls_op.y(), op.drop_remainder(), batchInputOp.f(), + op.output_types(), op.output_shapes(), + batchInputOp.preserve_cardinality()); + rewriter.replaceOp(op, {fused_op.handle()}); + return failure(); + } +}; + +#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_tf_data_optimization.inc" +} // namespace + +void PopulateTFDataOptimizationPatterns(MLIRContext *context, + OwningRewritePatternList *patterns) { + patterns->insert(context); + populateWithGenerated(context, patterns); +} + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.h 
b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.h new file mode 100644 index 00000000000..ffbc06a9515 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.h @@ -0,0 +1,32 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_DATA_OPTIMIZATION_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_DATA_OPTIMIZATION_H_ + +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project + +namespace mlir { +namespace TF { + +// Populates patterns to perform optimizations specific to tf.data operations. +void PopulateTFDataOptimizationPatterns(MLIRContext *context, + OwningRewritePatternList *patterns); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_DATA_OPTIMIZATION_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.td b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.td new file mode 100644 index 00000000000..4b4239679b2 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.td @@ -0,0 +1,32 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +include "mlir/IR/OpBase.td" +include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" + +// TODO(jpienaar): Move this somewhere general. 
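+// GetI64ScalarElementsAttr<value> materializes a rank-0 (scalar) i64
+// DenseElementsAttr holding `value`; it is used below to build the
+// num_parallel_calls operand of the fused MapAndBatchDataset.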
+class GetI64ScalarElementsAttr : + NativeCodeCall<"DenseElementsAttr::get(RankedTensorType::get({}, $_builder.getIntegerType(64)), " # value # ")">; + +def FuseMapAndBatch : Pat< + (TF_BatchDatasetV2Op + (TF_MapDatasetOp $input_dataset, $other_arguments, $f, $output_types, + $output_shapes, $use_inter_op_parallelism, $preserve_cardinality), + $batch_size, $drop_remainder, $parallel_copy, $batch_output_types, + $batch_output_shapes), + (TF_MapAndBatchDatasetOp $input_dataset, $other_arguments, $batch_size, + (TF_ConstOp (GetI64ScalarElementsAttr<1>)), $drop_remainder, $f, + $batch_output_types, $batch_output_shapes, $preserve_cardinality)>; + diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization_pass.cc new file mode 100644 index 00000000000..5be69bddb11 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization_pass.cc @@ -0,0 +1,40 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.h" + +namespace mlir { +namespace TF { +namespace { + +// Perform tf.data optimizations. +struct TFDataOptimization + : public PassWrapper { + void runOnFunction() override { + OwningRewritePatternList patterns; + mlir::TF::PopulateTFDataOptimizationPatterns(&getContext(), &patterns); + + applyPatternsAndFoldGreedily(getFunction(), patterns); + } +}; + +} // namespace +} // namespace TF +} // namespace mlir + +static mlir::PassRegistration pass( + "tf-data-optimization", "Performs tf.data optimizations"); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc index 0571701413a..6ea6df38568 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ // This transformation pass takes ops with the same `_tpu_replicate` attribute -// in a block and clusters them together under a `tf_device::LaunchOp`. +// in a block and clusters them together under a `tf_device.cluster`. // Associated TPUReplicateMetadata ops are removed and its attributes are copied -// over to the associated `tf_device::LaunchOp`. If a cluster should be +// over to the associated `tf_device.cluster`. If a cluster should be // replicated, the associated `tf_device::LaunchOp` will be wrapped further with // a `tf_device.replicate`. 
This pass also assumes ops of the same cluster do // not have ops outside of the cluster that are both operands and results of the @@ -179,7 +179,7 @@ llvm::SmallSetVector CollectClusterPrecedingUsers( // Collects results and associated types of the cluster that are used outside of // the cluster. These results and types are used to create the clusters -// `tf_device::LaunchOp` and associated terminator. Results that have no uses +// `tf_device.cluster` and associated terminator. Results that have no uses // outside of the cluster (i.e. results of ops in the cluster are only consumed // by other ops in the cluster) are pruned. llvm::SmallVector CollectClusterResults( @@ -201,40 +201,37 @@ llvm::SmallVector CollectClusterResults( return results; } -// Creates a `tf_device::LaunchOp` to wrap cluster ops. -tf_device::LaunchOp CreateLaunchOpForCluster(Operation* last_cluster_op, - llvm::ArrayRef results) { - // `tf_device::LaunchOp` will be placed at where the last op of the cluster - // is. +// Creates a `tf_device.cluster` to wrap cluster ops. +tf_device::ClusterOp CreateOpForCluster(Operation* last_cluster_op, + llvm::ArrayRef results) { + // `tf_device.cluster` will be placed at where the last op of the cluster is. OpBuilder builder(last_cluster_op); llvm::SmallVector result_types; for (Value result : results) result_types.push_back(result.getType()); - // An empty string placeholder is used for the device as that will be later - // populated with the device of the associated TPUReplicateMetadata op. - auto launch_op = builder.create( - last_cluster_op->getLoc(), builder.getStringAttr(""), result_types); + auto cluster = builder.create(last_cluster_op->getLoc(), + result_types); - launch_op.body().push_back(new Block); + cluster.body().push_back(new Block); // Add terminator. - builder.setInsertionPointToEnd(&launch_op.GetBody()); + builder.setInsertionPointToEnd(&cluster.GetBody()); builder.create(last_cluster_op->getLoc(), results); - return launch_op; + return cluster; } -// Moves cluster ops to associated `tf_device.LaunchOp` body. -void MoveClusterOpsToLaunchOp( - tf_device::LaunchOp launch_op, +// Moves cluster ops to associated `tf_device.cluster` body. +void MoveClusterOpsToCluster( + tf_device::ClusterOp cluster, const llvm::SmallSetVector& cluster_ops) { - MLIRContext* context = launch_op.getContext(); - Operation* terminator = &launch_op.GetBody().back(); + MLIRContext* context = cluster.getContext(); + Operation* terminator = cluster.GetBody().getTerminator(); for (Operation* cluster_op : cluster_ops) { // Remove `_tpu_replicate` and `device` attribute from ops in the cluster - // as that information will be present in the `tf_device.LaunchOp`. + // as that information will be present in the `tf_device.cluster`. cluster_op->removeAttr(Identifier::get(kTPUReplicateAttr, context)); cluster_op->removeAttr(Identifier::get(kDeviceAttr, context)); cluster_op->moveBefore(terminator); @@ -242,24 +239,24 @@ void MoveClusterOpsToLaunchOp( } // Replaces uses of cluster ops results outside of cluster with the associated -// `tf_device::LaunchOp` results. -void UpdateLaunchOpResultExternalUses(tf_device::LaunchOp launch_op, - llvm::ArrayRef results) { - Block& launch_op_block = launch_op.GetBody(); - for (auto ret_vals : llvm::zip(results, launch_op.getResults())) { +// `tf_device.cluster` results. 
+void UpdateClusterResultExternalUses(tf_device::ClusterOp cluster, + llvm::ArrayRef results) { + Block& cluster_block = cluster.GetBody(); + for (auto ret_vals : llvm::zip(results, cluster.getResults())) { Value old_ret = std::get<0>(ret_vals); Value new_ret = std::get<1>(ret_vals); for (auto& use : llvm::make_early_inc_range(old_ret.getUses())) - if (!launch_op_block.findAncestorOpInBlock(*use.getOwner())) + if (!cluster_block.findAncestorOpInBlock(*use.getOwner())) use.set(new_ret); } } // Moves users of cluster that are before the cluster to after the cluster. -void MovePrecedingClusterUsers(tf_device::LaunchOp launch_op, +void MovePrecedingClusterUsers(tf_device::ClusterOp cluster, llvm::ArrayRef preceding_users) { - Operation* op_after_launch_op = launch_op.getOperation()->getNextNode(); - for (Operation* user : preceding_users) user->moveBefore(op_after_launch_op); + Operation* op_after_cluster = cluster.getOperation()->getNextNode(); + for (Operation* user : preceding_users) user->moveBefore(op_after_cluster); } // Sorts `tf.TPUReplicatedInput` ops by `index` attribute. Ops with an `index` @@ -297,19 +294,18 @@ LogicalResult SortTPUReplicatedInputsByIndex( // Creates a `tf_device.replicate` to represent replication for the cluster, if // necessary. -LogicalResult ReplicateCluster(tf_device::LaunchOp launch_op, - int num_replicas) { +LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) { // No need to replicate. if (num_replicas == 1) return success(); if (num_replicas < 1) - return launch_op.emitError() << "requires '" << kNumReplicasAttr - << "' int attribute to be at least 1"; + return cluster.emitError() << "requires '" << kNumReplicasAttr + << "' int attribute to be at least 1"; // Collect all used TPUReplicatedInput ops and sort by `index`. llvm::SmallSetVector unique_replicated_input_ops; mlir::visitUsedValuesDefinedAbove( - launch_op.body(), launch_op.body(), [&](mlir::OpOperand* operand) { + cluster.body(), cluster.body(), [&](mlir::OpOperand* operand) { Operation* def = operand->get().getDefiningOp(); if (def && llvm::isa(def)) unique_replicated_input_ops.insert(def); @@ -339,24 +335,24 @@ LogicalResult ReplicateCluster(tf_device::LaunchOp launch_op, } // Create replicate op. - OpBuilder builder(launch_op); + OpBuilder builder(cluster); auto replicate_op = builder.create( - launch_op.getLoc(), num_replicas, + cluster.getLoc(), num_replicas, llvm::SmallDenseMap>(), - replicated_inputs, launch_op.getResultTypes()); + replicated_inputs, cluster.getResultTypes()); if (!mirrored_variable_indices.empty()) replicate_op.setAttr(kMirroredVariableIndicesAttr, builder.getI64ArrayAttr(mirrored_variable_indices)); // Replace replicated cluster results with replicate op results. 
- for (auto result_and_idx : llvm::enumerate(launch_op.getResults())) { + for (auto result_and_idx : llvm::enumerate(cluster.getResults())) { Value result = result_and_idx.value(); int idx = result_and_idx.index(); for (auto& use : result.getUses()) { Operation* def = use.getOwner(); if (!def || !llvm::isa(def)) - return launch_op.emitError() - << "requires output of " << launch_op.getOperationName() + return cluster.emitError() + << "requires output of " << cluster.getOperationName() << " to lead to a 'tf.TPUReplicatedOutput' op"; if (def->getNumResults() != num_replicas) @@ -375,14 +371,15 @@ LogicalResult ReplicateCluster(tf_device::LaunchOp launch_op, Operation* input = std::get<0>(input_and_block_arg); Value block_arg = std::get<1>(input_and_block_arg); mlir::replaceAllUsesInRegionWith(input->getResult(0), block_arg, - launch_op.body()); + cluster.body()); } - // Create terminator for replicate op and move launch into replicate. + // Create terminator for replicate op and move `tf_device.cluster` into + // replicate. builder.setInsertionPointToEnd(&replicate_op.GetBody()); auto return_op = builder.create(replicate_op.getLoc(), - launch_op.getResults()); - launch_op.getOperation()->moveBefore(return_op); + cluster.getResults()); + cluster.getOperation()->moveBefore(return_op); return success(); } @@ -396,31 +393,33 @@ LogicalResult ReplicateCluster(tf_device::LaunchOp launch_op, // `_tpu_replicate` attribute. // 2. Find users not in cluster that are interleaved between cluster ops. // 3. Find external uses of cluster ops. -// 4. Create `tf_device::LaunchOp` with results consisting of the external -// uses of cluster ops determined at 3. -// 5. Move cluster ops to `tf_device::LaunchOp` body. -// 6. Replace external uses of cluster ops uses with `tf_device::LaunchOp` +// 4. Create `tf_device.cluster` with results consisting of the external uses +// of cluster ops determined at 3. +// 5. Move cluster ops to `tf_device.cluster` body. +// 6. Replace external uses of cluster ops uses with `tf_device.cluster` // results. -// 7. Move users from 2 to after the `tf_device::LaunchOp`. -// 8. Wrap cluster (`tf_device::LaunchOp`) in a `tf_device.replicate` if +// 7. Move users from 2 to after the `tf_device.cluster`. +// 8. Wrap cluster (`tf_device.cluster`) in a `tf_device.replicate` if // attribute `num_replicas` is greater than 1. -// 9. Copy over TPUReplicateMetadata attributes to `tf_device::LaunchOp`. +// 9. Copy over TPUReplicateMetadata attributes to `tf_device.cluster`. LogicalResult FormClustersInBlock(Block* block, const MetadataMap& metadata_map) { ClusterMap clusters; LogicalResult result = CollectAndGroupClusterOps(block, &clusters); if (failed(result)) return result; - for (const auto& cluster : clusters) { - const auto& cluster_ops = cluster.getSecond(); + for (const auto& cluster_metadata_and_ops : clusters) { + const auto& cluster_ops = cluster_metadata_and_ops.getSecond(); - auto cluster_metadata = metadata_map.find(cluster.getFirst()); + auto cluster_metadata = + metadata_map.find(cluster_metadata_and_ops.getFirst()); // No TPUReplicateMetadata for a `_tpu_replicate` attribute. 
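    // For example, ops tagged _tpu_replicate = "cluster0" with no matching
    // TPUReplicateMetadata op in the block are left unclustered; only the
    // warning below is emitted.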
if (cluster_metadata == metadata_map.end()) { cluster_ops.front()->emitWarning() << "TPUReplicateMetadata for associated '" << kTPUReplicateAttr - << "' attribute '" << cluster.getFirst() << "' is missing"; + << "' attribute '" << cluster_metadata_and_ops.getFirst() + << "' is missing"; continue; } @@ -430,28 +429,28 @@ LogicalResult FormClustersInBlock(Block* block, llvm::SmallVector results = CollectClusterResults(block, cluster_ops); - tf_device::LaunchOp launch_op = - CreateLaunchOpForCluster(cluster_ops.back(), results); + tf_device::ClusterOp cluster = + CreateOpForCluster(cluster_ops.back(), results); - MoveClusterOpsToLaunchOp(launch_op, cluster_ops); + MoveClusterOpsToCluster(cluster, cluster_ops); - UpdateLaunchOpResultExternalUses(launch_op, results); + UpdateClusterResultExternalUses(cluster, results); - MovePrecedingClusterUsers(launch_op, preceding_users.getArrayRef()); + MovePrecedingClusterUsers(cluster, preceding_users.getArrayRef()); auto num_replicas = cluster_metadata->getSecond().get(kNumReplicasAttr); if (!num_replicas || !num_replicas.isa()) - return launch_op.emitError() + return cluster.emitError() << "requires '" << kNumReplicasAttr << "' int attribute"; if (failed(ReplicateCluster( - launch_op, num_replicas.cast().getInt()))) + cluster, num_replicas.cast().getInt()))) return failure(); - // Copy TPUReplicateMetadata attributes to launch. - launch_op.setAttrs(cluster_metadata->second); + // Copy TPUReplicateMetadata attributes to `tf_device.cluster`. + cluster.setAttrs(cluster_metadata->second); // Exclude `num_replicas` as cluster should be replicated if necessary. - launch_op.removeAttr(kNumReplicasAttr); + cluster.removeAttr(kNumReplicasAttr); } return success(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc index ad80eaaf1a6..64af2eabd3d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_padding_mapper.cc @@ -43,7 +43,7 @@ namespace TFTPU { constexpr char kPaddingMapAttr[] = "padding_map"; // This pass remaps and assigns padding maps to an encapsulated function's -// arguments from a `tf_device.launch_func` `padding_map` attribute. Remapping +// arguments from a `tf_device.cluster_func` `padding_map` attribute. Remapping // is from replicated input index to encapsulated function's operand index // (user). @@ -54,13 +54,13 @@ struct TPUDynamicPaddingMapper }; // Creates a mapping from replicated input index (in `tf_device.replicate` op) -// to `tf_device.launch_func` operand index. +// to `tf_device.cluster_func` operand index. 
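+// For example, if replicate block argument 2 is passed as cluster_func
+// operand 0, the returned map contains {2 -> 0}; operands that are not
+// replicate block arguments are simply omitted.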
llvm::SmallDenseMap GetRemappedReplicatedInputIndices( - tf_device::LaunchFuncOp launch_func, tf_device::ReplicateOp replicate) { + tf_device::ClusterFuncOp cluster_func, tf_device::ReplicateOp replicate) { Block* replicate_block = &replicate.GetBody(); llvm::SmallDenseMap remapped_indices; - for (auto operand_and_idx : llvm::enumerate(launch_func.getOperands())) + for (auto operand_and_idx : llvm::enumerate(cluster_func.getOperands())) if (auto block_arg = operand_and_idx.value().dyn_cast()) if (block_arg.getOwner() == replicate_block) remapped_indices[block_arg.getArgNumber()] = operand_and_idx.index(); @@ -68,11 +68,12 @@ llvm::SmallDenseMap GetRemappedReplicatedInputIndices( return remapped_indices; } -// Extracts `padding_map` from `tf_device.launch_func` and remaps the associated -// replicated input indices to the encapsulated function operand indices. An -// error will be returned if an index is not found or parsing failed. +// Extracts `padding_map` from `tf_device.cluster_func` and remaps the +// associated replicated input indices to the encapsulated function operand +// indices. An error will be returned if an index is not found or parsing +// failed. LogicalResult GetRemappedPaddings( - tf_device::LaunchFuncOp launch_func, int num_replicated_args, + tf_device::ClusterFuncOp cluster_func, int num_replicated_args, const llvm::SmallDenseMap& remapped_indices, llvm::SmallVectorImpl* remapped_paddings) { auto bad_index_msg = [num_replicated_args](int32_t index, @@ -85,12 +86,12 @@ LogicalResult GetRemappedPaddings( .str(); }; - Attribute padding_map_attr = launch_func.getAttr(kPaddingMapAttr); + Attribute padding_map_attr = cluster_func.getAttr(kPaddingMapAttr); if (!padding_map_attr) return success(); auto padding_map = padding_map_attr.dyn_cast(); if (!padding_map) - return launch_func.emitOpError() + return cluster_func.emitOpError() << "requires '" << kPaddingMapAttr << "' array attribute"; for (auto padding_attr_and_idx : llvm::enumerate(padding_map)) { @@ -98,25 +99,25 @@ LogicalResult GetRemappedPaddings( auto& padding_attr = padding_attr_and_idx.value(); auto padding = padding_attr.dyn_cast(); if (!padding) - return launch_func.emitOpError( + return cluster_func.emitOpError( llvm::formatv("bad '{0}' attribute at index {1}, not a string", kPaddingMapAttr, padding_attr_and_idx.index())); tensorflow::tpu::PaddingMap padding_proto; if (!padding_proto.ParseFromString(padding.getValue().str())) - return launch_func.emitOpError(llvm::formatv( + return cluster_func.emitOpError(llvm::formatv( "bad '{0}' attribute at index {1}, failed to parse '{2}' as " "tensorflow::tpu::PaddingMap", kPaddingMapAttr, idx, padding.getValue())); const int32_t arg_index = padding_proto.arg_index(); if (arg_index >= num_replicated_args || arg_index < 0) - return launch_func.emitOpError() + return cluster_func.emitOpError() << bad_index_msg(idx, "arg_index", arg_index); const int32_t padding_arg_index = padding_proto.padding_arg_index(); if (padding_arg_index >= num_replicated_args || padding_arg_index < 0) - return launch_func.emitOpError() + return cluster_func.emitOpError() << bad_index_msg(idx, "padding_arg_index", padding_arg_index); auto arg_index_it = remapped_indices.find(arg_index); @@ -125,7 +126,7 @@ LogicalResult GetRemappedPaddings( auto padding_arg_index_it = remapped_indices.find(padding_arg_index); if (padding_arg_index_it == remapped_indices.end()) { - launch_func.emitWarning(llvm::formatv( + cluster_func.emitWarning(llvm::formatv( "bad '{0}' attribute at index {1}, unused 
padding_arg_index {2}", kPaddingMapAttr, idx, padding_arg_index)); continue; @@ -169,22 +170,21 @@ void AnnotateFunctionArgumentsWithPaddings( } } -LogicalResult RemapAndAssignPaddingMaps(tf_device::LaunchFuncOp launch_func, +LogicalResult RemapAndAssignPaddingMaps(tf_device::ClusterFuncOp cluster_func, SymbolTable* symbol_table) { - auto replicate = - llvm::dyn_cast_or_null(launch_func.getParentOp()); + auto replicate = cluster_func.getParentOfType(); // LaunchFunc is not replicated, there will be no padding. if (!replicate) return success(); const int num_replicated_args = replicate.GetBody().getNumArguments(); - auto func = symbol_table->lookup(launch_func.func()); + auto func = symbol_table->lookup(cluster_func.func()); if (!func) return success(); llvm::SmallDenseMap remapped_indices = - GetRemappedReplicatedInputIndices(launch_func, replicate); + GetRemappedReplicatedInputIndices(cluster_func, replicate); llvm::SmallVector remapped_paddings; - if (failed(GetRemappedPaddings(launch_func, num_replicated_args, + if (failed(GetRemappedPaddings(cluster_func, num_replicated_args, remapped_indices, &remapped_paddings))) return failure(); @@ -196,8 +196,8 @@ LogicalResult RemapAndAssignPaddingMaps(tf_device::LaunchFuncOp launch_func, void TPUDynamicPaddingMapper::runOnOperation() { ModuleOp module = getOperation(); SymbolTable symbol_table(module); - module.walk([&](tf_device::LaunchFuncOp launch_func) { - RemapAndAssignPaddingMaps(launch_func, &symbol_table); + module.walk([&](tf_device::ClusterFuncOp cluster_func) { + RemapAndAssignPaddingMaps(cluster_func, &symbol_table); }); } } // anonymous namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc new file mode 100644 index 00000000000..b9e214470cd --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc @@ -0,0 +1,231 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" + +namespace mlir { +namespace TFTPU { + +// This pass extracts a CPU computation cluster with `_xla_outside_compilation` +// annotation from the head or tail of a TPU cluster. + +namespace { + +constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation"; + +bool HasOutsideCompilationAttribute(Operation* op) { + return op->getAttrOfType(kXlaOutsideCompilationAttr) != nullptr; +} + +// Returns whether all operands of `op` are from values inside the +// `input_value_set`. +bool OpContainsOperandsFromSet(Operation* op, + const llvm::SetVector& input_value_set) { + for (auto operand : op->getOperands()) + if (input_value_set.count(operand) == 0) return false; + + return true; +} + +void RecordOutsideCompiledOpsAndUsages( + Operation* op, llvm::SmallSetVector* outside_compiled_ops, + llvm::SetVector* outside_compiled_op_usages) { + if (HasOutsideCompilationAttribute(op) && + OpContainsOperandsFromSet(op, *outside_compiled_op_usages)) { + outside_compiled_ops->insert(op); + outside_compiled_op_usages->insert(op->getResults().begin(), + op->getResults().end()); + } +} + +// Traverses the MLIR graph and returns a set of ops that +// are connected to inputs of TPU computation and outside compiled. +void ExtractOutsideCompiledOpsConnectedToHead( + Value input_value, llvm::SetVector* values_used_in_host_cluster, + llvm::SmallSetVector* outside_compiled_ops) { + llvm::SmallSetVector parent_outside_compiled_ops_at_head; + for (auto& usage : input_value.getUses()) { + auto head_operation = usage.getOwner(); + RecordOutsideCompiledOpsAndUsages(head_operation, + &parent_outside_compiled_ops_at_head, + values_used_in_host_cluster); + } + + // Traverse the graph and find all outside compiled ops connected from + // the `input_value`. + while (!parent_outside_compiled_ops_at_head.empty()) { + llvm::SmallSetVector connected_outside_compiled_ops; + for (auto head_outside_compiled_op : parent_outside_compiled_ops_at_head) { + auto op_results = head_outside_compiled_op->getOpResults(); + for (auto op_result : op_results) { + for (auto& use : op_result.getUses()) { + auto connected_op = use.getOwner(); + RecordOutsideCompiledOpsAndUsages(connected_op, + &connected_outside_compiled_ops, + values_used_in_host_cluster); + } + } + } + + outside_compiled_ops->insert(parent_outside_compiled_ops_at_head.begin(), + parent_outside_compiled_ops_at_head.end()); + std::swap(parent_outside_compiled_ops_at_head, + connected_outside_compiled_ops); + } +} + +// TODO(hongjunchoi): Also handle ops without inputs that are outside +// compiled. 
+// +// Returns set of ops that are outside compiled and are directly connected +// to inputs to the TPU computation. +llvm::SmallSetVector IdentifyOutsideCompiledOpsAtHead( + tf_device::ClusterOp tpu_cluster) { + llvm::SmallSetVector outside_compiled_at_head_ops; + llvm::SetVector values_used_in_cluster; + auto& cluster_region = tpu_cluster.body(); + getUsedValuesDefinedAbove(cluster_region, cluster_region, + values_used_in_cluster); + + auto input_value_list = llvm::to_vector<8>(values_used_in_cluster); + for (auto input_value : input_value_list) + ExtractOutsideCompiledOpsConnectedToHead( + input_value, &values_used_in_cluster, &outside_compiled_at_head_ops); + return outside_compiled_at_head_ops; +} + +// Returns output values of extracted outside compiled cluster at head that +// are used by the TPU computation. +llvm::SmallVector GetHeadExtractedClusterOutputs( + const llvm::SmallSetVector& head_outside_compiled_ops) { + llvm::SmallVector outputs; + outputs.reserve(head_outside_compiled_ops.size()); + + for (auto op : head_outside_compiled_ops) { + for (Operation* user : op->getUsers()) { + if (!head_outside_compiled_ops.count(user)) { + outputs.append(op->result_begin(), op->result_end()); + break; + } + } + } + + return outputs; +} + +// Creates new tf_device.launch op with outside compiled ops extracted +// from the head of TPU computation. +llvm::Optional IsolateHeadExtractedOpsToLaunchOp( + OpBuilder* builder, tf_device::ClusterOp cluster, + const llvm::SmallSetVector& head_outside_compiled_ops) { + if (head_outside_compiled_ops.empty()) + return llvm::Optional(); + + // Create tf_device.launch op to separate all extracted outside compiled ops + // before the tf_device.cluster. + auto output_values = + GetHeadExtractedClusterOutputs(head_outside_compiled_ops); + + llvm::SmallVector output_return_types; + output_return_types.reserve(output_values.size()); + for (auto output : output_values) + output_return_types.emplace_back(output.getType()); + + builder->setInsertionPoint(cluster); + auto host_launch_op = builder->create( + cluster.getLoc(), builder->getStringAttr(""), output_return_types); + + // Replace all usages of outside compiled ops that are used in TPU + // computation with the results of the above created launch op. + for (auto output_and_index : llvm::enumerate(output_values)) { + auto output_index = output_and_index.index(); + auto output = output_and_index.value(); + for (auto& use : output.getUses()) { + if (!head_outside_compiled_ops.count(use.getOwner())) + use.set(host_launch_op.getResult(output_index)); + } + } + + // Create terminator op for the newly created launch op. + host_launch_op.body().push_back(new Block()); + builder->setInsertionPointToEnd(&host_launch_op.GetBody()); + auto terminator = builder->create( + host_launch_op.getLoc(), output_values); + + // Move all outside compile ops from cluster op to launch op. + for (auto outside_compiled_op : head_outside_compiled_ops) + outside_compiled_op->moveBefore(terminator); + + return host_launch_op; +} + +struct TPUExtractHeadTailOutsideCompilation + : public PassWrapper> { + void runOnOperation() override; +}; + +void TPUExtractHeadTailOutsideCompilation::runOnOperation() { + // Get runtime devices information from the closest parent module. 
+ auto module = getOperation(); + mlir::TF::RuntimeDevices devices; + if (failed(tensorflow::GetDevicesFromOp(module, &devices))) + return signalPassFailure(); + + OpBuilder builder(&getContext()); + module.walk([&](tf_device::ClusterOp cluster) { + auto head_outside_compiled_ops = IdentifyOutsideCompiledOpsAtHead(cluster); + IsolateHeadExtractedOpsToLaunchOp(&builder, cluster, + head_outside_compiled_ops); + + // TODO(b/156030523): Update device attribute of newly created host launch + // op as well as enclosing Replicate op (if TPU computation is replicated) + // with host device names. + + // TODO(b/155115766): Implement tail outside compiled op extraction. + }); +} + +} // anonymous namespace + +std::unique_ptr> +CreateTPUExtractHeadTailOutsideCompilationPass() { + return std::make_unique(); +} + +static PassRegistration pass( + "tf-tpu-extract-head-tail-outside-compilation", + "Extracts TPU head or tail outside compilation to separate " + "parallel_execute."); + +} // namespace TFTPU +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc index 4e20cd9d64b..4281b85bd7f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc @@ -34,7 +34,7 @@ constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation"; constexpr char kDeviceAttr[] = "device"; // Mapping for `_xla_outside_compilation` attribute to ops of a cluster. -using ClusterMap = +using OutsideClusterMap = llvm::SmallDenseMap, 8>; // This pass extracts a CPU computation cluster with `_xla_outside_compilation` @@ -51,7 +51,8 @@ struct TPUExtractOutsideCompilation // Collects and clusters ops in `block` with the same `_xla_outside_compilation` // attribute into `clusters` This returns an error if a // `_xla_outside_compilation` attribute of an op is empty. -LogicalResult CollectAndGroupClusterOps(Block* block, ClusterMap* clusters) { +LogicalResult CollectAndGroupOutsideClusterOps(Block* block, + OutsideClusterMap* clusters) { for (Operation& op : *block) { if (auto attr = op.getAttrOfType(kXlaOutsideCompilationAttr)) { if (attr.getValue().empty()) @@ -67,7 +68,7 @@ LogicalResult CollectAndGroupClusterOps(Block* block, ClusterMap* clusters) { } // Moves `cluster_ops` to associated `launch_op` body. -void MoveClusterOpsToLaunchOp( +void MoveOutsideClusterOpsToLaunchOp( tf_device::LaunchOp launch_op, const llvm::SmallVector& cluster_ops) { MLIRContext* context = launch_op.getContext(); @@ -84,8 +85,8 @@ void MoveClusterOpsToLaunchOp( } // Creates a `tf_device::LaunchOp` to wrap cluster ops. -tf_device::LaunchOp CreateLaunchOpForCluster(OpBuilder* builder, - Operation* last_cluster_op) { +tf_device::LaunchOp CreateLaunchOpForOutsideCluster( + OpBuilder* builder, Operation* last_cluster_op) { // TODO(b/154363171): Set the CPU device. // An empty string placeholder is used for the device as that will be later // populated with the device of the associated TPUReplicateMetadata op. @@ -117,14 +118,14 @@ void PropagateParallelExecuteReturnToReplicate( // Creates a `parallel_execute` op in place of launch with 'clusters` and // 'launch` as regions. 
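As a point of reference for how the tf-tpu-extract-head-tail-outside-compilation pass added above could be exercised, the following is a minimal sketch, not taken from this patch, that runs it over a module through an MLIR PassManager. It assumes the factory CreateTPUExtractHeadTailOutsideCompilationPass is declared in the already-included transforms/passes.h header and that the pass operates on ModuleOp, as the registration suggests.

#include "mlir/IR/Module.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"

// Sketch only: run the head/tail outside compilation extraction pass on a
// module so that outside compiled ops at the head of each tf_device.cluster
// end up in a preceding tf_device.launch.
mlir::LogicalResult RunHeadTailExtraction(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::TFTPU::CreateTPUExtractHeadTailOutsideCompilationPass());
  return pm.run(module);
}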
-void CreateParallelExecuteFromClusters(tf_device::LaunchOp launch, - const ClusterMap& clusters) { - OpBuilder builder(launch); +void CreateParallelExecuteFromOutsideClusters( + tf_device::ClusterOp tpu_cluster, const OutsideClusterMap& clusters) { + OpBuilder builder(tpu_cluster); // Create parallel_execute regions. The original TPU cluster computation // is the extra region. int num_regions = 1 + clusters.size(); auto parallel_execute_op = builder.create( - launch.getLoc(), num_regions, launch.results().getTypes()); + tpu_cluster.getLoc(), num_regions, tpu_cluster.results().getTypes()); // Move outside compilation clusters to parallel_execute regions. for (const auto& cluster : llvm::enumerate(clusters)) { @@ -134,21 +135,23 @@ void CreateParallelExecuteFromClusters(tf_device::LaunchOp launch, parallel_execute_op.GetRegionBlockWithIndex(cluster.index()); builder.setInsertionPointToEnd(&outside_block); tf_device::LaunchOp launch_op = - CreateLaunchOpForCluster(&builder, cluster_ops.back()); - MoveClusterOpsToLaunchOp(launch_op, cluster_ops); + CreateLaunchOpForOutsideCluster(&builder, cluster_ops.back()); + MoveOutsideClusterOpsToLaunchOp(launch_op, cluster_ops); builder.setInsertionPointToEnd(&outside_block); // TODO(b/154363171): Handle returns from OutsideCompiled parallel_execute // regions either through communication with TPU parallel_execute regions // or modifying parallel_execute returns. - builder.create(launch.getLoc(), ArrayRef{}); + builder.create(tpu_cluster.getLoc(), + ArrayRef{}); } // Move the launch body to last parallel_execute block. Block& inside_block = parallel_execute_op.GetRegionBlockWithIndex(num_regions - 1); builder.setInsertionPointToEnd(&inside_block); - builder.create(launch.getLoc(), launch.getResults()); - launch.getOperation()->moveBefore(inside_block.getTerminator()); + builder.create(tpu_cluster.getLoc(), + tpu_cluster.getResults()); + tpu_cluster.getOperation()->moveBefore(inside_block.getTerminator()); PropagateParallelExecuteReturnToReplicate(parallel_execute_op); // TODO(b/154363171): Handle returns from OutsideCompiled parallel_execute @@ -157,17 +160,19 @@ void CreateParallelExecuteFromClusters(tf_device::LaunchOp launch, } void TPUExtractOutsideCompilation::runOnFunction() { - auto extract_result = getFunction().walk([&](tf_device::LaunchOp launch) { - ClusterMap clusters; - if (failed(CollectAndGroupClusterOps(&launch.GetBody(), &clusters))) - return WalkResult::interrupt(); + auto extract_result = + getFunction().walk([&](tf_device::ClusterOp tpu_cluster) { + OutsideClusterMap clusters; + if (failed(CollectAndGroupOutsideClusterOps(&tpu_cluster.GetBody(), + &clusters))) + return WalkResult::interrupt(); - if (clusters.empty()) return WalkResult::advance(); + if (clusters.empty()) return WalkResult::advance(); - CreateParallelExecuteFromClusters(launch, clusters); + CreateParallelExecuteFromOutsideClusters(tpu_cluster, clusters); - return WalkResult::advance(); - }); + return WalkResult::advance(); + }); if (extract_result.wasInterrupted()) return signalPassFailure(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc index a635fdb9a1f..f5e9da915c8 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc @@ -82,17 +82,17 @@ constexpr char kBadArrayElementMsg[] = constexpr char kBadArrayAttrLengthMsg[] = "bad '{0}' attribute, expected array attribute of 
size {1}, got size {2}"; -// Rewrites `tf_device.launch_func` operations assigned to TPU into actual TPU +// Rewrites `tf_device.cluster_func` operations assigned to TPU into actual TPU // jit-compile runtime ops. // // For example: -// %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster", func = +// %1 = "tf_device.cluster_func"(%0) {_tpu_replicate = "cluster", func = // @tpu_func} // %2 = "tf.SomeOp"(%1) // // Would become following ops (unimportant attributes, types are omitted): // %1 = "tf.Shape"(%0) -// %2:2 = "tf.MLIRCompileToTPU"(%1) {module = ""} +// %2:2 = "tf._TPUCompileMlir"(%1) {module = ""} // "tf.TPUCompileSucceededAssert"(%2#0) // %3 = "tf.TPUExecute"(%0, %2#1) // %4 = "tf.SomeOp"(%3) @@ -165,7 +165,7 @@ LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func, // Extracts device coordinates from a device assignment attribute on an op. LogicalResult GetDeviceCoordinates( - tf_device::LaunchFuncOp op, + tf_device::ClusterFuncOp op, llvm::SmallVectorImpl* device_assignment) { auto device_assignment_attr = op.getAttrOfType(kDeviceAssignmentAttr); @@ -190,9 +190,9 @@ LogicalResult GetDeviceCoordinates( } // Populates a TPUCompileMetadataProto with StepMarkerLocation from a -// `tf_device::LaunchFuncOp`. +// `tf_device::ClusterFuncOp`. LogicalResult SetMetadataProtoStepMarkerLocation( - tf_device::LaunchFuncOp op, + tf_device::ClusterFuncOp op, tensorflow::tpu::TPUCompileMetadataProto* metadata) { auto step_marker_location = op.getAttrOfType(kStepMarkerLocationAttr); @@ -216,9 +216,9 @@ LogicalResult SetMetadataProtoStepMarkerLocation( } // Populates a TPUCompileMetadataProto with PaddingMap from a -// `tf_device::LaunchFuncOp`. +// `tf_device::ClusterFuncOp`. LogicalResult SetMetadataProtoPaddingMap( - tf_device::LaunchFuncOp op, + tf_device::ClusterFuncOp op, tensorflow::tpu::TPUCompileMetadataProto* metadata) { auto padding_map = op.getAttrOfType(kPaddingMapAttr); if (!padding_map) @@ -259,9 +259,9 @@ LogicalResult SetOpSharding(Operation* op, Attribute attr, llvm::StringRef name, } // Populates a TPUCompileMetadataProto with argument types and sharding from a -// `tf_device::LaunchFuncOp`. +// `tf_device::ClusterFuncOp`. LogicalResult SetMetadataProtoArgs( - tf_device::LaunchFuncOp op, + tf_device::ClusterFuncOp op, tensorflow::tpu::TPUCompileMetadataProto* metadata) { auto input_shardings = op.getAttrOfType(tensorflow::kInputShardingAttr); @@ -314,9 +314,9 @@ LogicalResult SetMetadataProtoArgs( } // Populates a TPUCompileMetadataProto with result sharding from a -// `tf_device::LaunchFuncOp`. +// `tf_device::ClusterFuncOp`. LogicalResult SetMetadataProtoRetvals( - tf_device::LaunchFuncOp op, + tf_device::ClusterFuncOp op, tensorflow::tpu::TPUCompileMetadataProto* metadata) { auto output_shardings = op.getAttrOfType(tensorflow::kOutputShardingAttr); @@ -341,11 +341,11 @@ LogicalResult SetMetadataProtoRetvals( } // Populates a TPUCompileMetadataProto from attributes of a -// `tf_device::LaunchFuncOp`. If any necessary attributes are missing from the +// `tf_device::ClusterFuncOp`. If any necessary attributes are missing from the // op, a failure will be returned. // TODO(lyandy): Support session handle and guaranteed consts. 
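For orientation before the renamed helper below: the proto being populated is the ordinary TPUCompileMetadataProto from tensorflow/core/protobuf/tpu/compile_metadata.proto. The sketch that follows is illustrative only and is not part of the patch; it shows just the replica/core counters for the simplest case (one replica, one core per replica), assumes the generated header path follows the proto location, and omits the argument, return-value and device-assignment fields that the other helpers above fill in.

#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"

// Illustrative sketch: the replica/core fields set for a non-replicated,
// single-core cluster. All other metadata fields are left at their defaults.
tensorflow::tpu::TPUCompileMetadataProto MakeMinimalTpuCompileMetadata() {
  tensorflow::tpu::TPUCompileMetadataProto metadata;
  metadata.set_num_replicas(1);
  metadata.set_num_cores_per_replica(1);
  return metadata;
}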
-LogicalResult SetMetadataProtoFromLaunchFuncOp( - tf_device::LaunchFuncOp op, int num_replicas, int num_cores_per_replica, +LogicalResult SetMetadataProtoFromClusterFuncOp( + tf_device::ClusterFuncOp op, int num_replicas, int num_cores_per_replica, llvm::Optional&& xla_device_assignment, tensorflow::tpu::TPUCompileMetadataProto* metadata) { metadata->set_num_replicas(num_replicas); @@ -377,7 +377,7 @@ tf_device::LaunchOp WrapOpInLaunch(OpBuilder* builder, Location loc, builder->setInsertionPointToEnd(&launch.GetBody()); builder->create(loc, op->getResults()); - // Move op inside launch. + // Move op inside cluster. op->moveBefore(launch.GetBody().getTerminator()); builder->restoreInsertionPoint(insert_point); @@ -386,16 +386,16 @@ tf_device::LaunchOp WrapOpInLaunch(OpBuilder* builder, Location loc, } // Create a `tf._TPUCompileMlir` that contains a MLIR module that is -// functionally equivalent to the function referenced by launch_func. +// functionally equivalent to the function referenced by cluster_func. Operation* BuildCompileOp( - tf_device::LaunchFuncOp launch_func, int num_replicas, + tf_device::ClusterFuncOp cluster_func, int num_replicas, int num_cores_per_replica, llvm::StringRef compilation_device, llvm::Optional&& xla_device_assignment, OpBuilder* builder) { // Set metadata from attributes. tensorflow::tpu::TPUCompileMetadataProto metadata; - if (failed(SetMetadataProtoFromLaunchFuncOp( - launch_func, num_replicas, num_cores_per_replica, + if (failed(SetMetadataProtoFromClusterFuncOp( + cluster_func, num_replicas, num_cores_per_replica, std::move(xla_device_assignment), &metadata))) return nullptr; @@ -405,28 +405,28 @@ Operation* BuildCompileOp( else metadata.SerializeToString(&txt_metadata); - // Build a shape op for each input to launch_func. + // Build a shape op for each input to cluster_func. // TODO(b/139377366): When shape inference is ready, we can use compile time // shape inference to get inputs that have static shapes and only use shape // ops for the rest. llvm::SmallVector compile_op_operands; - compile_op_operands.reserve(launch_func.getNumOperands()); + compile_op_operands.reserve(cluster_func.getNumOperands()); - for (auto operand_and_idx : llvm::enumerate(launch_func.getOperands())) { + for (auto operand_and_idx : llvm::enumerate(cluster_func.getOperands())) { // Skip adding shape op for operands that have static shapes. tensorflow::PartialTensorShape shape( metadata.args(operand_and_idx.index()).shape()); if (shape.IsFullyDefined()) continue; auto shape_op = builder->create( - launch_func.getLoc(), + cluster_func.getLoc(), RankedTensorType::get({-1}, builder->getIntegerType(64)), operand_and_idx.value()); compile_op_operands.emplace_back(shape_op.getResult()); } - FlatSymbolRefAttr func_attr = launch_func.funcAttr(); - FuncOp func = launch_func.getParentOfType().lookupSymbol( + FlatSymbolRefAttr func_attr = cluster_func.funcAttr(); + FuncOp func = cluster_func.getParentOfType().lookupSymbol( func_attr.getValue()); std::string txt_module; @@ -436,7 +436,7 @@ Operation* BuildCompileOp( RankedTensorType::get({}, builder->getType()); auto compile_op = builder->create( - launch_func.getLoc(), /*compilation_status=*/result_type, /*program=*/ + cluster_func.getLoc(), /*compilation_status=*/result_type, /*program=*/ llvm::SmallVector(num_cores_per_replica, result_type), compile_op_operands, txt_module, txt_metadata); @@ -448,19 +448,20 @@ Operation* BuildCompileOp( // core, and all replica devices per core are grouped together. 
void AssignDevicesToReplicate( tf_device::ReplicateOp replicate, - llvm::ArrayRef> execution_devices, + llvm::ArrayRef> + tpu_devices, OpBuilder* builder) { if (!replicate) return; - const int num_replicas = execution_devices.size(); - const int num_cores_per_replica = execution_devices.front().size(); + const int num_replicas = tpu_devices.size(); + const int num_cores_per_replica = tpu_devices.front().size(); llvm::SmallVector device_attrs; for (int core = 0; core < num_cores_per_replica; ++core) { llvm::SmallVector devices_by_core; devices_by_core.reserve(num_replicas); for (int replica = 0; replica < num_replicas; ++replica) - devices_by_core.push_back(execution_devices[replica][core]); + devices_by_core.push_back(tpu_devices[replica][core].device); device_attrs.push_back( builder->getNamedAttr(tensorflow::GetDeviceAliasForLogicalCore(core), @@ -473,18 +474,18 @@ void AssignDevicesToReplicate( // Creates a `tf.TPUExecute` op that executes TPU program. LogicalResult BuildExecuteOp( const int core_id, llvm::ArrayRef output_sharding_config, - llvm::ArrayRef inputs, tf_device::LaunchFuncOp launch_func, + llvm::ArrayRef inputs, tf_device::ClusterFuncOp cluster_func, OpBuilder* builder, TF::TPUExecuteOp* execute_op) { // TODO(b/139377366): Need to snapshot all resource variable inputs in // follow-up CLs. llvm::SmallVector output_types; auto result = tensorflow::GetOutputTypesForLogicalDeviceComputation( - core_id, output_sharding_config, launch_func, &output_types); + core_id, output_sharding_config, cluster_func, &output_types); if (failed(result)) return failure(); - // TPUExecute has same output types as launch_func. + // TPUExecute has same output types as cluster_func. *execute_op = builder->create( - launch_func.getLoc(), output_types, inputs, + cluster_func.getLoc(), output_types, inputs, llvm::ArrayRef{}); return success(); } @@ -492,32 +493,33 @@ LogicalResult BuildExecuteOp( // Creates a tf_device.parallel_execute op that wraps TPUExecute op to // represent execution of TPU program in multiple logical cores. LogicalResult BuildParallelExecuteOp( - llvm::ArrayRef> execution_devices, + llvm::ArrayRef> + tpu_devices, llvm::ArrayRef output_sharding_config, - Operation* compile_op, tf_device::LaunchFuncOp launch_func, + Operation* compile_op, tf_device::ClusterFuncOp cluster_func, OpBuilder* builder, tf_device::ParallelExecuteOp* parallel_execute_op) { - const int num_cores_per_replica = execution_devices.front().size(); + const int num_cores_per_replica = tpu_devices.front().size(); // parallel_execute op returns concatenated list of return values of // all its regions. // // TODO(b/149102702): Correctly map inputs to parallel_execute op via - // identifying xla_sharding op in the launch_func function. - const auto& launch_result_types = launch_func.getResultTypes(); + // identifying xla_sharding op in the cluster_func function. 
+ const auto cluster_result_types = cluster_func.getResultTypes(); llvm::SmallVector concatenated_output_types; - concatenated_output_types.reserve(launch_result_types.size() * + concatenated_output_types.reserve(cluster_result_types.size() * num_cores_per_replica); for (int core = 0; core < num_cores_per_replica; ++core) { llvm::SmallVector output_types; auto result = tensorflow::GetOutputTypesForLogicalDeviceComputation( - core, output_sharding_config, launch_func, &output_types); + core, output_sharding_config, cluster_func, &output_types); if (failed(result)) return failure(); for (Type t : output_types) concatenated_output_types.emplace_back(t); } *parallel_execute_op = builder->create( - launch_func.getLoc(), num_cores_per_replica, concatenated_output_types); + cluster_func.getLoc(), num_cores_per_replica, concatenated_output_types); // Extract inputs for each region of the parallel_execute op. The i-th // element in the list represents the input lists to TPU computation for @@ -525,10 +527,10 @@ LogicalResult BuildParallelExecuteOp( llvm::SmallVector, 4> input_list; builder->setInsertionPoint(*parallel_execute_op); auto result = tensorflow::ExtractInputsForLogicalDevices( - num_cores_per_replica, launch_func, builder, &input_list); + num_cores_per_replica, cluster_func, builder, &input_list); if (failed(result)) return failure(); - const bool replicated = execution_devices.size() != 1; + const bool replicated = tpu_devices.size() != 1; // For each logical core, create a region with TPUExecute op. assert(input_list.size() == num_cores_per_replica); for (int core = 0; core < num_cores_per_replica; ++core) { @@ -539,13 +541,13 @@ LogicalResult BuildParallelExecuteOp( // // TODO(b/148913294): Identify inputs/return values specific to each // logical core TPU execution by parsing xla_sharding op in - // launch_func. + // cluster_func. auto execute_inputs = input_list[core]; execute_inputs.emplace_back(compile_op->getResult(core + 1)); TF::TPUExecuteOp execute; result = BuildExecuteOp(core, output_sharding_config, execute_inputs, - launch_func, builder, &execute); + cluster_func, builder, &execute); if (failed(result)) return failure(); // If computation is replicated, use aliased device. Otherwise there is only @@ -553,7 +555,7 @@ LogicalResult BuildParallelExecuteOp( // op. std::string device = replicated ? tensorflow::GetDeviceAliasForLogicalCore(core) - : execution_devices.front()[core]; + : tpu_devices.front()[core].device; auto region_launch_op = WrapOpInLaunch(builder, region.getParent()->getLoc(), execute, device); @@ -566,13 +568,14 @@ LogicalResult BuildParallelExecuteOp( } tf_device::LaunchOp AssignDevicesToReplicatedExecute( - llvm::ArrayRef> execution_devices, + llvm::ArrayRef> + tpu_devices, Operation* execute_op, OpBuilder* builder) { - const bool replicated = execution_devices.size() != 1; + const bool replicated = tpu_devices.size() != 1; // If computation is replicated, use aliased device. Otherwise there is only // one execution device and the device is assigned to the execute op. std::string device = replicated ? 
tensorflow::GetDeviceAliasForLogicalCore(0) - : execution_devices.front().front(); + : tpu_devices.front().front().device; return WrapOpInLaunch(builder, execute_op->getLoc(), execute_op, device); } @@ -587,16 +590,16 @@ void BuildTPUCompileSucceededAssertOp(Operation* compile_op, WrapOpInLaunch(builder, compile_op->getLoc(), assert_op, compilation_device); } -// Rewrites a `tf_device.launch_func` operation into a set of TPU Runtime -// Operations that jit-compiles and executes function in `tf_device.launch_func` -// on TPU. Device assignment is determined from available devices in `devices`. -// If it is not possible to rewrite the operation or device assignment fails, a -// failure will be returned. +// Rewrites a `tf_device.cluster_func` operation into a set of TPU Runtime +// Operations that jit-compiles and executes function in +// `tf_device.cluster_func` on TPU. Device assignment is determined from +// available devices in `devices`. If it is not possible to rewrite the +// operation or device assignment fails, a failure will be returned. // -// For example, a non replicated `tf_device.launch_func`: +// For example, a non replicated `tf_device.cluster_func`: // // func @main(%arg0: tensor) { -// %0 = "tf_device.launch_func"(%arg0) +// %0 = "tf_device.cluster_func"(%arg0) // {_tpu_replicate = "cluster0", device = "", func = @_func} : // (tensor) -> tensor // return @@ -613,12 +616,12 @@ void BuildTPUCompileSucceededAssertOp(Operation* compile_op, // return // } // -// and a replicated `tf_device.launch_func`: +// and a replicated `tf_device.cluster_func`: // // func @main(%arg0: tensor, %arg1: tensor) { // %0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor) // {n = 2 : i32} { -// %1 = "tf_device.launch_func"(%ri) +// %1 = "tf_device.cluster_func"(%ri) // {_tpu_replicate = "cluster0", device = "", func = @_func} : // (tensor) -> tensor // tf_device.return %1 : tensor @@ -641,36 +644,37 @@ void BuildTPUCompileSucceededAssertOp(Operation* compile_op, // return // } LogicalResult Rewrite( - tf_device::LaunchFuncOp launch_func, + tf_device::ClusterFuncOp cluster_func, llvm::ArrayRef devices, OpBuilder* builder) { - // Skip non-tpu device launch_func. - auto replicate_attr = launch_func.getAttrOfType("_tpu_replicate"); + // Skip non-tpu device cluster_func. + auto replicate_attr = + cluster_func.getAttrOfType("_tpu_replicate"); if (!replicate_attr) return success(); // Collect `num_replicas` and `num_cores_per_replica` attributes. int num_replicas = 1; tf_device::ReplicateOp replicate = - launch_func.getParentOp() + cluster_func.getParentOp() ? 
llvm::dyn_cast_or_null( - launch_func.getParentOp()) + cluster_func.getParentOp()) : nullptr; if (replicate) num_replicas = replicate.n().getLimitedValue(); auto num_cores_per_replica_attr = - launch_func.getAttrOfType(kNumCoresPerReplicaAttr); + cluster_func.getAttrOfType(kNumCoresPerReplicaAttr); if (!num_cores_per_replica_attr) - return launch_func.emitOpError( + return cluster_func.emitOpError( CreateMissingAttributeMsg(kNumCoresPerReplicaAttr)); int num_cores_per_replica = num_cores_per_replica_attr.getInt(); - auto topology_attr = launch_func.getAttrOfType(kTopologyAttr); + auto topology_attr = cluster_func.getAttrOfType(kTopologyAttr); if (!topology_attr) - return launch_func.emitOpError(CreateMissingAttributeMsg(kTopologyAttr)); + return cluster_func.emitOpError(CreateMissingAttributeMsg(kTopologyAttr)); llvm::SmallVector device_assignment; - if (failed(GetDeviceCoordinates(launch_func, &device_assignment))) + if (failed(GetDeviceCoordinates(cluster_func, &device_assignment))) return failure(); // Determine compilation and execution devices. @@ -679,15 +683,25 @@ LogicalResult Rewrite( devices, num_replicas, num_cores_per_replica, topology_attr.getValue(), device_assignment); if (!status_or_tpu_device_assignment.ok()) - return launch_func.emitError() + return cluster_func.emitError() << "error in fetching TPU compilation/execution devices: " << status_or_tpu_device_assignment.status().error_message(); // Create compile op. auto& tpu_device_assignment = status_or_tpu_device_assignment.ValueOrDie(); - builder->setInsertionPoint(launch_func); + builder->setInsertionPoint(cluster_func); + + // Create the TPUCompileMlir and TPUCompileSucceededAssert outside of + // parallel_execute region if it exists. + if (llvm::isa(cluster_func.getParentOp())) { + // Currently, outside compilation and model parallelism are not supported + // together. + assert(num_cores_per_replica == 1); + builder->setInsertionPoint(cluster_func.getParentOp()); + } + Operation* compile_op = BuildCompileOp( - launch_func, num_replicas, num_cores_per_replica, + cluster_func, num_replicas, num_cores_per_replica, tpu_device_assignment.compilation_device, std::move(tpu_device_assignment.xla_device_assignment), builder); if (!compile_op) return failure(); @@ -696,54 +710,55 @@ LogicalResult Rewrite( // the same _tpu_replicate attribute and replace it with the result of the // compile op. This op is used as a placeholder to hook during graph creation // the other ops that are intended to consume the compile result. - Block* block = launch_func.getOperation()->getBlock(); + Block* block = cluster_func.getOperation()->getBlock(); for (auto compile_result_op : block->getOps()) compile_result_op.output().replaceAllUsesWith(compile_op->getResult(0)); BuildTPUCompileSucceededAssertOp( compile_op, tpu_device_assignment.compilation_device, builder); - AssignDevicesToReplicate(replicate, tpu_device_assignment.execution_devices, + AssignDevicesToReplicate(replicate, tpu_device_assignment.tpu_devices, builder); llvm::SmallVector output_shardings; auto result = tensorflow::ParseAndValidateOutputSharding( - num_cores_per_replica, launch_func, &output_shardings); + num_cores_per_replica, cluster_func, &output_shardings); if (failed(result)) return failure(); + builder->setInsertionPoint(cluster_func); if (num_cores_per_replica > 1) { // For model parallelism, tf_device.parallel_execute is used to express // concurrent device execution across multiple logical devices. 
tf_device::ParallelExecuteOp execute_op; - result = BuildParallelExecuteOp(tpu_device_assignment.execution_devices, - output_shardings, compile_op, launch_func, + result = BuildParallelExecuteOp(tpu_device_assignment.tpu_devices, + output_shardings, compile_op, cluster_func, builder, &execute_op); if (failed(result)) return failure(); // As tf_device.parallel_execute wraps # logical cores number of TPUExecute // ops, the number of return values of parallel_execute op exceeds that of - // launch_func op. As so, each return value of parallel_execute op must be - // mapped with corresponding return value usages of launch_func. - tensorflow::RemapOutputsFromLogicalDevices(launch_func.getLoc(), - output_shardings, launch_func, + // cluster_func op. As so, each return value of parallel_execute op must be + // mapped with corresponding return value usages of cluster_func. + tensorflow::RemapOutputsFromLogicalDevices(cluster_func.getLoc(), + output_shardings, cluster_func, execute_op, builder); } else { - llvm::SmallVector execute_inputs(launch_func.getOperands()); + llvm::SmallVector execute_inputs(cluster_func.getOperands()); execute_inputs.emplace_back(compile_op->getResult(1)); TF::TPUExecuteOp execute_op; result = BuildExecuteOp( - /*core_id=*/0, output_shardings, execute_inputs, launch_func, builder, + /*core_id=*/0, output_shardings, execute_inputs, cluster_func, builder, &execute_op); if (failed(result)) return failure(); tf_device::LaunchOp launch_op = AssignDevicesToReplicatedExecute( - tpu_device_assignment.execution_devices, execute_op, builder); - launch_func.replaceAllUsesWith(launch_op); + tpu_device_assignment.tpu_devices, execute_op, builder); + cluster_func.replaceAllUsesWith(launch_op); } - launch_func.erase(); + cluster_func.erase(); return success(); } @@ -754,7 +769,7 @@ void TPURewritePass::runOnOperation() { return signalPassFailure(); OpBuilder builder(&getContext()); - auto result = getOperation().walk([&](tf_device::LaunchFuncOp op) { + auto result = getOperation().walk([&](tf_device::ClusterFuncOp op) { if (failed(Rewrite(op, devices.device_names(), &builder))) return WalkResult::interrupt(); @@ -777,7 +792,7 @@ std::unique_ptr> CreateTPURewritePass() { static PassRegistration pass( "tf-tpu-rewrite", - "Rewriting `tf_device.launch_func` on TPUs into TPU runtime ops"); + "Rewriting `tf_device.cluster_func` on TPUs into TPU runtime ops"); } // namespace TFTPU } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc index ce627737646..f8b6e364f55 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc @@ -47,19 +47,19 @@ struct TPUShardingIdentificationPass void runOnOperation() override; }; -// Sets `sharding_op` if `op` is XlaShardingOp or if XlaSharding op is -// adjacent to `op`. XlaSharding op may be direct user of inputs but it -// may also be followed by an Identity op and, in the case where bfloat16 -// type is used, Cast op may be added right after the input. As so, -// parse the users of the operation to access connected XlaSharding op. +// Sets `sharding_op` if `op` is XlaShardingOp or if XlaSharding op is adjacent +// to `op`. 
XlaSharding op may be direct user of inputs but it may also be +// followed by an Identity op and, in the case where bfloat16 type is used, Cast +// op may be added right after the input. As so, parse the users of the +// operation to access connected XlaSharding op. // -// TODO(hongjunchoi): Consider explicitly checking op patterns to detect -// sharded inputs. +// TODO(hongjunchoi): Consider explicitly checking op patterns to detect sharded +// inputs. void GetAdjacentXlaShardingOp(Operation* op, llvm::Optional* sharding_op) { - // TODO(hongjunchoi): Detect the case when sharding configuration is - // ambiguous for a single input (i.e. multiple different XlaSharding ops - // with different configuration policies are connected). + // TODO(hongjunchoi): Detect the case when sharding configuration is ambiguous + // for a single input (i.e. multiple different XlaSharding ops with different + // configuration policies are connected). if (sharding_op->hasValue()) return; if (auto sharding = llvm::dyn_cast(op)) { @@ -74,11 +74,11 @@ void GetAdjacentXlaShardingOp(Operation* op, } // Parses XlaSharding op connected to input args. If Input to -// tf_device.LaunchFunc op is of resource type, then XlaSharding op -// will be connected to following ReadVariable op. +// tf_device.ClusterFunc op is of resource type, then XlaSharding op will be +// connected to following ReadVariable op. // -// TODO(hongjunchoi): Add logic to parse XlaSharding op inside a -// Call op or if/while op. +// TODO(hongjunchoi): Add logic to parse XlaSharding op inside a Call op or +// If/While op. llvm::Optional ParseInputSharding(const Value& arg) { llvm::Optional parsed_sharding_op; for (auto user : arg.getUsers()) { @@ -96,8 +96,8 @@ llvm::Optional ParseInputSharding(const Value& arg) { return parsed_sharding_op.getValue()._XlaSharding(); } -// Returns the provided sharding configuration if operand of return value -// of tf_device.LaunchFunc op is directly from XlaSharding op, +// Returns the provided sharding configuration if operand of return value of +// tf_device.ClusterFunc op is directly from XlaSharding op, llvm::Optional ParseReturnValueSharding(FuncOp func, const int output_index, const OpOperand& operand) { @@ -108,16 +108,16 @@ llvm::Optional ParseReturnValueSharding(FuncOp func, return llvm::Optional(); } -// Includes information on Func op and argument index of the input value. -// This is used to trace Value that is fed into function call ops. +// Includes information on Func op and argument index of the input value. This +// is used to trace Value that is fed into function call ops. struct FunctionAndArgumentInfo { FuncOp func; int argument_index; }; -// Adds tf.PartitionedCall op or tf.StatefulPartitionedCall op to `list`. -// If `op` is a function call op, then find the func op from provided `module` -// and add the func op with `arg_index` to `list`. `list` will later be used to +// Adds tf.PartitionedCall op or tf.StatefulPartitionedCall op to `list`. If +// `op` is a function call op, then find the func op from provided `module` and +// add the func op with `arg_index` to `list`. `list` will later be used to // trace mlir::Value that is fed into (potentially nested) function call ops. void AddFunctionalOpsToList( const int arg_index, ModuleOp module, Operation* op, @@ -138,8 +138,8 @@ void AddFunctionalOpsToList( } } -// Walks the MLIR graph from `arg` and return a list of all function -// call ops to which the `arg` op is directly connected. 
+// Walks the MLIR graph from `arg` and return a list of all function call ops to +// which the `arg` op is directly connected. // // For example: // argument0 -> PartitionedCallOp -> StatefulPartitionedCallOp -> AddOp @@ -177,31 +177,33 @@ llvm::SmallVector ExtractFunctionsConnectedToArg( return functions_connected_to_arg; } -// Walks the graph from the arguments of the `launch_func_op` and extracts -// sharding configurations for all inputs by parsing XlaSharding op connected -// to the arguments. If argument to the `launch_func_op` directly feeds into +// Walks the graph from the arguments of the `cluster_func_op` and extracts +// sharding configurations for all inputs by parsing XlaSharding op connected to +// the arguments. If argument to the `cluster_func_op` directly feeds into // another function call op, then recursively walk the function definition to // find the connected XlaSharding op. void IdentifyXlaShardingForComputationInputs( - StringRef logical_core_0_sharding, tf_device::LaunchFuncOp launch_func_op, - FuncOp launch_function, Builder* builder) { + StringRef logical_core_0_sharding, tf_device::ClusterFuncOp cluster_func_op, + FuncOp cluster_function, Builder* builder) { // Look up function definition from module. - Block& launch_function_block = launch_function.getBody().getBlocks().front(); - ModuleOp module = launch_func_op.getParentOfType(); + Block& cluster_function_block = + cluster_function.getBody().getBlocks().front(); + ModuleOp module = cluster_func_op.getParentOfType(); llvm::SmallVector sharding_for_args( - launch_function_block.getNumArguments(), logical_core_0_sharding); + cluster_function_block.getNumArguments(), logical_core_0_sharding); - // Iterate through input arguments to the entry block of tf_device.LaunchFunc. - // For input ops, look for following XlaSharding ops. XlaSharding ops can: + // Iterate through input arguments to the entry block of + // tf_device.ClusterFunc. For input ops, look for following XlaSharding ops. + // XlaSharding ops can: // 1) Directly follow the input argument if input argument has non-resource // types. // 2) Follow ReadVariableOp if the input type is of resource type. // 3) Follow IdentityOp or CastOp after above cases (1), (2). // - // Sharding configurations are added to the tf_device.LaunchFunc as an + // Sharding configurations are added to the tf_device.ClusterFunc as an // attribute and the function as an argument attribute. - for (auto& arg : launch_function_block.getArguments()) { + for (auto& arg : cluster_function_block.getArguments()) { auto arg_sharding = ParseInputSharding(arg); const int arg_index_to_tpu_computation = arg.getArgNumber(); @@ -222,25 +224,25 @@ void IdentifyXlaShardingForComputationInputs( if (arg_sharding) { sharding_for_args[arg_index_to_tpu_computation] = arg_sharding.getValue(); - launch_function.setArgAttr( + cluster_function.setArgAttr( arg_index_to_tpu_computation, kShardingAttr, builder->getStringAttr(arg_sharding.getValue())); } else { - launch_function.setArgAttr( + cluster_function.setArgAttr( arg_index_to_tpu_computation, kShardingAttr, builder->getStringAttr(logical_core_0_sharding)); } } - launch_func_op.setAttr(tensorflow::kInputShardingAttr, - builder->getStrArrayAttr(sharding_for_args)); + cluster_func_op.setAttr(tensorflow::kInputShardingAttr, + builder->getStrArrayAttr(sharding_for_args)); } // Parses XlaSharding op directly connected from the outputs of the -// `launch_func` and extract sharding configurations for outputs. 
+// `cluster_func` and extract sharding configurations for outputs. void IdentifyXlaShardingForComputationOutputs( StringRef logical_core_0_sharding, FuncOp func, - tf_device::LaunchFuncOp launch_func, Builder* builder) { + tf_device::ClusterFuncOp cluster_func, Builder* builder) { // By default return values from logical core 0 is used if no sharding // configuration is defined. Block& function_block = func.getBody().getBlocks().front(); @@ -250,7 +252,7 @@ void IdentifyXlaShardingForComputationOutputs( // Iterate through operands of the terminator. If the preceding op is // XlaShardingOp, then the provided sharding configuration is added to the - // tf_device.LaunchFunc as an attribute and the function as a result + // tf_device.ClusterFunc as an attribute and the function as a result // attribute. for (auto& ret : terminator->getOpOperands()) { const int index = ret.getOperandNumber(); @@ -265,35 +267,35 @@ void IdentifyXlaShardingForComputationOutputs( builder->getStringAttr(logical_core_0_sharding)); } } - launch_func.setAttr(tensorflow::kOutputShardingAttr, - builder->getStrArrayAttr(sharding_for_rets)); + cluster_func.setAttr(tensorflow::kOutputShardingAttr, + builder->getStrArrayAttr(sharding_for_rets)); } -// Extracts input/output sharding configuration of `launch_func` by parsing -// XlaSharding ops inside the `launch_func`. -void IdentifyXlaShardingForTPUComputation(Builder* builder, - tf_device::LaunchFuncOp launch_func) { +// Extracts input/output sharding configuration of `cluster_func` by parsing +// XlaSharding ops inside the `cluster_func`. +void IdentifyXlaShardingForTPUComputation( + Builder* builder, tf_device::ClusterFuncOp cluster_func) { // Look up function definition from module. - FuncOp func = launch_func.getParentOfType().lookupSymbol( - launch_func.func()); + FuncOp func = cluster_func.getParentOfType().lookupSymbol( + cluster_func.func()); - // By default inputs/outputs have maximal sharding and are assigned to - // logical core 0 if no sharding is defined. + // By default inputs/outputs have maximal sharding and are assigned to logical + // core 0 if no sharding is defined. const std::string logical_core_0_sharding = xla::sharding_builder::AssignDevice(0).SerializeAsString(); - IdentifyXlaShardingForComputationInputs(logical_core_0_sharding, launch_func, + IdentifyXlaShardingForComputationInputs(logical_core_0_sharding, cluster_func, func, builder); IdentifyXlaShardingForComputationOutputs(logical_core_0_sharding, func, - launch_func, builder); + cluster_func, builder); } void TPUShardingIdentificationPass::runOnOperation() { Builder builder(getOperation().getContext()); - getOperation().walk([&](tf_device::LaunchFuncOp launch_func) { - IdentifyXlaShardingForTPUComputation(&builder, launch_func); + getOperation().walk([&](tf_device::ClusterFuncOp cluster_func) { + IdentifyXlaShardingForTPUComputation(&builder, cluster_func); }); } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index af8b4f064dd..a613ce1f920 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -40,10 +40,10 @@ limitations under the License. 
#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSet.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Analysis/Verifier.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project @@ -57,6 +57,8 @@ limitations under the License. #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Verifier.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/jit/shape_inference_helpers.h" #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h" @@ -65,6 +67,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" @@ -109,6 +112,7 @@ static inline absl::string_view StringRefToView(llvm::StringRef ref) { } namespace tensorflow { +using mlir::NamedAttrList; using mlir::TensorType; using mlir::TF::VarHandleOp; using mlir::tf_saved_model::GlobalTensorOp; @@ -306,9 +310,9 @@ class ImporterBase { // AttrValue {name : foo, attrs : {k1 : bar, k2 : rfc}}, it will convert it to // a list of MLIR Attributes: [{base_name : foo}, {base_name.k1 : bar}, // {base_name.k2 : rfc}}. - Status ConvertFunctionCallAttribute( - const std::string& base_name, const AttrValue& value, - llvm::SmallVector* attributes); + Status ConvertFunctionCallAttribute(const std::string& base_name, + const AttrValue& value, + NamedAttrList* attributes); // Helper to create either a tf_executor operation or a TF operation wrapped // in an island. When convert_to_legacy_call is true, converts the operation @@ -1089,9 +1093,9 @@ StatusOr ImporterBase::ConvertSubtypes( return subtypes; } -Status ImporterBase::ConvertFunctionCallAttribute( - const std::string& base_name, const AttrValue& value, - llvm::SmallVector* attributes) { +Status ImporterBase::ConvertFunctionCallAttribute(const std::string& base_name, + const AttrValue& value, + NamedAttrList* attributes) { TF_ASSIGN_OR_RETURN(auto func_attr, ConvertFunctionCallName(value.func().name())); attributes->push_back(builder_.getNamedAttr(base_name, func_attr)); @@ -1817,6 +1821,8 @@ Status ImporterBase::ConvertNode(const Node& node) { absl::c_stable_sort(in_edges, [](const Edge* e1, const Edge* e2) { if (e1->IsControlEdge() && !e2->IsControlEdge()) return false; if (!e1->IsControlEdge() && e2->IsControlEdge()) return true; + if (e1->IsControlEdge() && e2->IsControlEdge()) + return e1->src()->id() < e2->src()->id(); return e1->dst_input() < e2->dst_input(); }); @@ -2426,8 +2432,8 @@ class SavedModelObjectGraphImporter : public ImporterBase { // Main entry point: converts all functions in the given meta graph to an MLIR // Module. 
static StatusOr Convert( - SavedModelV2Bundle* saved_model, mlir::MLIRContext* context, - absl::Span exported_names, bool add_default_attributes); + SavedModelV2Bundle* saved_model, absl::Span exported_names, + mlir::MLIRContext* context, bool add_default_attributes); private: explicit SavedModelObjectGraphImporter( @@ -3127,8 +3133,8 @@ Status CreateSavedModelIR( } StatusOr SavedModelObjectGraphImporter::Convert( - SavedModelV2Bundle* saved_model, mlir::MLIRContext* context, - absl::Span exported_names, bool add_default_attributes) { + SavedModelV2Bundle* saved_model, absl::Span exported_names, + mlir::MLIRContext* context, bool add_default_attributes) { GraphDebugInfo dummy_debug_info; const GraphDebugInfo& debug_info = saved_model->debug_info() ? *saved_model->debug_info() : dummy_debug_info; @@ -3205,17 +3211,20 @@ class SavedModelSignatureDefImporter { public: // Main entry point: converts all functions (specified by SignatureDefs) in // the given meta graph to an MLIR Module. - static StatusOr Convert(const SavedModelBundle& bundle, - mlir::MLIRContext* context) { - SavedModelSignatureDefImporter importer(bundle, context); + static StatusOr Convert( + const SavedModelBundle& bundle, absl::Span exported_names, + mlir::MLIRContext* context) { + SavedModelSignatureDefImporter importer(bundle, exported_names, context); return importer.ConvertSignatures(); } private: SavedModelSignatureDefImporter(const SavedModelBundle& bundle, + absl::Span exported_names, mlir::MLIRContext* context) : bundle_(bundle), + exported_names_(exported_names), module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))) {} // Converts the SavedModel to the SavedModel dialect. Creates an MLIR function @@ -3248,6 +3257,7 @@ class SavedModelSignatureDefImporter { const std::vector>& inputs); const SavedModelBundle& bundle_; + absl::Span exported_names_; mlir::OwningModuleRef module_; }; @@ -3263,6 +3273,9 @@ SavedModelSignatureDefImporter::ConvertSignatures() { GraphDebugInfo debug_info; if (bundle_.debug_info != nullptr) debug_info = *bundle_.debug_info; + llvm::StringSet<> exported_name_set; + exported_name_set.insert(exported_names_.begin(), exported_names_.end()); + for (const auto& key_and_signature_def : signatures) { const std::string& sig_def_key = key_and_signature_def.first; const SignatureDef& signature_def = key_and_signature_def.second; @@ -3272,6 +3285,10 @@ SavedModelSignatureDefImporter::ConvertSignatures() { if (sig_def_key == "__saved_model_init_op") { continue; } + if (!exported_name_set.empty() && + exported_name_set.count(sig_def_key) == 0) { + continue; + } TF_RETURN_IF_ERROR(ConvertSignature(graphdef, sig_def_key, signature_def, debug_info, flib_def)); @@ -3554,12 +3571,14 @@ StatusOr ConvertSavedModelToMlir( SavedModelV2Bundle* saved_model, mlir::MLIRContext* context, absl::Span exported_names, bool add_default_attributes) { return SavedModelObjectGraphImporter::Convert( - saved_model, context, exported_names, add_default_attributes); + saved_model, exported_names, context, add_default_attributes); } StatusOr ConvertSavedModelV1ToMlir( - const SavedModelBundle& saved_model, mlir::MLIRContext* context) { - return SavedModelSignatureDefImporter::Convert(saved_model, context); + const SavedModelBundle& saved_model, absl::Span exported_names, + mlir::MLIRContext* context) { + return SavedModelSignatureDefImporter::Convert(saved_model, exported_names, + context); } std::string MlirModuleToString(mlir::ModuleOp module, bool show_debug_info) { diff --git 
a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h index 8603eadb487..bdb72345201 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h @@ -55,6 +55,7 @@ stream_executor::port::StatusOr ConvertSavedModelToMlir( // expressed with tf_executor dialect. stream_executor::port::StatusOr ConvertSavedModelV1ToMlir(const SavedModelBundle& saved_model, + absl::Span exported_names, mlir::MLIRContext* context); // Serialize a MLIR module to a string. diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc index f4d3ff443a0..cb3a3be22d8 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc @@ -15,9 +15,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.h" -#include "mlir/Analysis/Verifier.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Verifier.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc index 2c7f84d8268..6ada0fec4e2 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc @@ -141,7 +141,8 @@ mlir::OwningModuleRef SavedModelObjectGraphToMlirImport( mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport( absl::string_view saved_model_dir, - const std::unordered_set& tags, mlir::MLIRContext* context) { + const std::unordered_set& tags, + absl::Span exported_names, mlir::MLIRContext* context) { tensorflow::SavedModelBundle bundle; tensorflow::SessionOptions session_options; // Force saved model states to be restored to CPU. @@ -155,7 +156,7 @@ mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport( return nullptr; } - auto module_or = ConvertSavedModelV1ToMlir(bundle, context); + auto module_or = ConvertSavedModelV1ToMlir(bundle, exported_names, context); if (!module_or.status().ok()) { LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status(); return nullptr; diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h index f498864c8aa..490b7c7d8f0 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h @@ -64,7 +64,8 @@ mlir::OwningModuleRef SavedModelObjectGraphToMlirImport( // given MLIR `context`. 
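As a usage note for the exported_names parameter threaded through the V1 importer above: an empty span preserves the previous behavior of importing every SignatureDef, while a non-empty one restricts the import to the named signatures. The sketch below is not part of the patch; the directory, tag and signature key are placeholders, and the element types of the span and set (elided in this hunk) are assumed to be std::string.

#include <string>
#include <unordered_set>
#include <vector>

#include "absl/types/span.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"

// Sketch: import only the "serving_default" signature of a v1 SavedModel.
// Passing an empty exported_names span would import every SignatureDef.
mlir::OwningModuleRef ImportServingSignature(mlir::MLIRContext* context) {
  const std::string saved_model_dir = "/tmp/my_saved_model";  // placeholder
  const std::unordered_set<std::string> tags = {"serve"};
  std::vector<std::string> exported_names = {"serving_default"};
  return tensorflow::SavedModelSignatureDefsToMlirImport(
      saved_model_dir, tags, absl::MakeSpan(exported_names), context);
}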
mlir::OwningModuleRef SavedModelSignatureDefsToMlirImport( absl::string_view saved_model_dir, - const std::unordered_set& tags, mlir::MLIRContext* context); + const std::unordered_set& tags, + absl::Span exported_names, mlir::MLIRContext* context); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc index 8212c0b50a4..06805e633e2 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc @@ -52,4 +52,11 @@ void BridgeLoggerConfig::printAfterIfEnabled(mlir::Pass* pass, Log(print_callback, pass, operation, "after"); } +void BridgeTimingConfig::printTiming(PrintCallbackFn printCallback) { + std::string name = "mlir_bridge_pass_timing.txt"; + std::unique_ptr os; + std::string filepath; + if (CreateFileForDumping(name, &os, &filepath).ok()) printCallback(*os); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h index b5b2ad33b31..eaf3a7c2598 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h @@ -44,6 +44,13 @@ class BridgeLoggerConfig : public mlir::PassManager::IRPrinterConfig { PrintCallbackFn print_callback) override; }; +// Logger for logging/dumping pass pipeline timings after completion. +class BridgeTimingConfig : public mlir::PassManager::PassTimingConfig { + public: + // Hook that control how/where is the output produced + void printTiming(PrintCallbackFn printCallback) override; +}; + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_BRIDGE_LOGGER_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index b891682366b..e8ca691f961 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -17,10 +17,13 @@ limitations under the License. #include "absl/types/optional.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project @@ -35,6 +38,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" @@ -289,6 +293,12 @@ Status ConvertMLIRToXlaComputation( tf2xla.addPass(mlir::xla_hlo::createLegalizeTfWithTf2XlaPass(device_type)); tf2xla.addNestedPass(mlir::createCanonicalizerPass()); + // Run shape inference pass to propagate shapes through tensor_cast operations + // from static to dynamic shapes. This could be generated if the shape + // inference was originally missing in a TF op but the corresponding HLO op + // had static shape after lowering. + tf2xla.addPass(mlir::TF::CreateTFShapeInferencePass()); + // Run LegalizeTFPass again because the previous legalization passes can // expose more graph pruning and canonicalization opportunities that are // necessary for the second LegalizeTFPass(allow_partial_conversion=false) @@ -299,7 +309,7 @@ Status ConvertMLIRToXlaComputation( if (VLOG_IS_ON(1)) { // Print the whole module after each pass which requires disabling // multi-threading as well. - tf2xla.disableMultithreading(); + module_op.getContext()->disableMultithreading(); tf2xla.enableIRPrinting(std::make_unique( /*print_module_scope=*/true)); } @@ -393,14 +403,47 @@ Status CompileSerializedMlirToXlaHlo( std::move(custom_legalization_passes)); } +// Rewrites the given module with specified args. For each of the constant args, +// it gets inlined in the "main' function and the corresponding argument is +// removed from the signature. +// Returns the original indices for the other arguments on success. +static StatusOr> RewriteWithArgs( + mlir::ModuleOp module, llvm::ArrayRef args) { + mlir::FuncOp main_fn = module.lookupSymbol("main"); + std::vector params; + + auto builder = mlir::OpBuilder(main_fn.getBody()); + std::vector args_to_erase; + for (int idx = 0; idx < args.size(); idx++) { + const XlaCompiler::Argument& xla_arg = args[idx]; + mlir::BlockArgument mlir_arg = main_fn.getArgument(idx); + if (xla_arg.kind != XlaCompiler::Argument::kConstant) { + params.push_back(idx); + continue; + } + + TF_ASSIGN_OR_RETURN(auto value_attr, + ConvertTensor(xla_arg.constant_value, &builder)); + // TODO(hinsu): Use the actual location of the constant. 
+ auto constant = builder.create( + mlir::UnknownLoc::get(module.getContext()), value_attr); + mlir_arg.replaceAllUsesWith(constant); + args_to_erase.push_back(idx); + } + + for (int idx : llvm::reverse(args_to_erase)) main_fn.eraseArgument(idx); + return params; +} + Status CompileGraphToXlaHlo( - const Graph& graph, llvm::ArrayRef arg_shapes, + const Graph& graph, llvm::ArrayRef args, llvm::StringRef device_type, bool use_tuple_args, const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info, const XlaCompiler::ShapeRepresentationFn shape_representation_fn, XlaCompiler::CompilationResult* compilation_result, std::vector> custom_legalization_passes) { RegisterDialects(); + mlir::MLIRContext context; GraphImportConfig config; config.graph_as_function = true; @@ -408,10 +451,19 @@ Status CompileGraphToXlaHlo( ConvertGraphToMlir(graph, debug_info, flib_def, config, &context); if (!module_or.ok()) return module_or.status(); - return CompileMlirToXlaHlo(module_or.ValueOrDie().get(), arg_shapes, - device_type, use_tuple_args, - shape_representation_fn, compilation_result, - std::move(custom_legalization_passes)); + mlir::ModuleOp module = module_or.ValueOrDie().get(); + TF_ASSIGN_OR_RETURN(std::vector remaining_params, + RewriteWithArgs(module, {args.data(), args.size()})); + llvm::SmallVector arg_shapes; + arg_shapes.reserve(args.size()); + for (unsigned idx : remaining_params) + arg_shapes.push_back(absl::get(args[idx].shape)); + + auto status = CompileMlirToXlaHlo( + module, arg_shapes, device_type, use_tuple_args, shape_representation_fn, + compilation_result, std::move(custom_legalization_passes)); + compilation_result->input_mapping = remaining_params; + return status; } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h index 0218efb83c6..24b60dcb346 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h @@ -71,7 +71,7 @@ Status CompileSerializedMlirToXlaHlo( // Same as the above but takes input as TensorFlow Graph. 
Status CompileGraphToXlaHlo( - const Graph& graph, llvm::ArrayRef arg_shapes, + const Graph& graph, llvm::ArrayRef args, llvm::StringRef device_type, bool use_tuple_args, const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info, const XlaCompiler::ShapeRepresentationFn shape_representation_fn, diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc index 118af434629..91640aff437 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc @@ -455,8 +455,12 @@ TEST(CompileGraphToXlaHlo, Basic) { test::graph::Retval(&graph, 0, arg); XlaCompiler::CompilationResult result; + XlaCompiler::Argument compiler_arg; + compiler_arg.kind = XlaCompiler::Argument::kParameter; + compiler_arg.shape = TensorShape(); + TF_ASSERT_OK( - CompileGraphToXlaHlo(graph, /*arg_shapes=*/{TensorShape()}, "XLA_CPU_JIT", + CompileGraphToXlaHlo(graph, /*args=*/{compiler_arg}, "XLA_CPU_JIT", /*use_tuple_args=*/false, flib_def, GraphDebugInfo(), /*shape_representation_fn=*/nullptr, &result)); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc index 1c1d127d42f..b28f26b6c3c 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc @@ -31,12 +31,14 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/bfloat16/bfloat16.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/tstring.h" #include "tensorflow/stream_executor/lib/statusor.h" @@ -85,16 +87,22 @@ StatusOr ConvertFlatTensor(const Tensor& input_tensor, type, llvm::makeArrayRef(arr.data(), arr.size())); } -StatusOr ConvertBF16Tensor(const Tensor& input_tensor, - ShapedType type) { +ElementsAttr ConvertBf16Tensor(const Tensor& input_tensor, + RankedTensorType type) { auto flat = input_tensor.flat(); + llvm::SmallVector floats; + floats.reserve(flat.size()); + for (bfloat16 v : llvm::makeArrayRef(flat.data(), flat.size())) + floats.push_back(llvm::APFloat(static_cast(v))); + return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(floats)); +} - llvm::SmallVector flat_double; - flat_double.reserve(flat.size()); - for (bfloat16 v : llvm::makeArrayRef(flat.data(), flat.size())) { - flat_double.push_back(static_cast(v)); - } - return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(flat_double)); +ElementsAttr ConvertHalfTensor(const Tensor& tensor, RankedTensorType type) { + auto buffer = llvm::makeArrayRef(static_cast(tensor.data()), + tensor.TotalBytes()); + return mlir::DenseElementsAttr::getFromRawBuffer( + type, buffer, + /*isSplatBuffer=*/type.getNumElements() == 1); } StatusOr ConvertStringTensor(const Tensor& input_tensor, @@ -125,18 +133,28 @@ StatusOr ConvertTensor(const Tensor& input_tensor, case DTYPE: \ return 
ConvertFlatTensor(input_tensor, type); - // TODO(fengliuai): customize the conversions for more types. + // TODO(fengliuai): customize the conversions for quantized and string types. switch (input_dtype) { CONVERT_FLAT(DT_BOOL, bool) CONVERT_FLAT(DT_FLOAT, float) CONVERT_FLAT(DT_DOUBLE, double) + CONVERT_FLAT(DT_INT8, int8) + CONVERT_FLAT(DT_INT16, int16) CONVERT_FLAT(DT_INT32, int32) CONVERT_FLAT(DT_INT64, int64) + CONVERT_FLAT(DT_UINT8, uint8) + CONVERT_FLAT(DT_UINT16, uint16) + CONVERT_FLAT(DT_UINT32, uint32) + CONVERT_FLAT(DT_UINT64, uint64) + CONVERT_FLAT(DT_COMPLEX64, std::complex) + CONVERT_FLAT(DT_COMPLEX128, std::complex) // BFLOAT16 is a special case that it needs to be cast to double type to // match its storage type. case DT_BFLOAT16: - return ConvertBF16Tensor(input_tensor, type); + return ConvertBf16Tensor(input_tensor, type); + case DT_HALF: + return ConvertHalfTensor(input_tensor, type); case DT_STRING: return ConvertStringTensor(input_tensor, type); @@ -199,12 +217,20 @@ mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type) { // Converts an MLIR dense string elements attribute to a TensorFlow tensor // proto. -Status ConvertStringElementsAttr(const DenseStringElementsAttr attr, - TensorProto* output_tensor) { - for (const auto& val : attr.getRawStringData()) { - output_tensor->add_string_val(val.data(), val.size()); +void ConvertStringElementsAttr( + const DenseStringElementsAttr attr, + protobuf::RepeatedPtrField* output) { + for (const auto& val : attr.getRawStringData()) + output->Add({val.data(), val.size()}); +} + +template +void ConvertComplexElementsAttr(const mlir::DenseElementsAttr attr, + protobuf::RepeatedField* output) { + for (const auto& val : attr.getValues>()) { + output->Add(val.real()); + output->Add(val.imag()); } - return Status::OK(); } // Converts an MLIR opaque elements attribute to a TensorFlow tensor proto. @@ -218,139 +244,80 @@ Status ConvertOpaqueElementsAttr(const ElementsAttr attr, return InvalidArgument("Unexpected elements attribute type from MLIR."); } -// Converts an MLIR elements attribute to a TensorFlow tensor proto -// with the double_val field updated. -Status ConvertDoubleElementsAttr(const ElementsAttr attr, - TensorProto* output_tensor) { - if (auto elts = attr.dyn_cast()) { - if (elts.isSplat()) { - output_tensor->add_double_val(elts.getSplatValue()); - } else { - for (auto value : elts.getValues()) - output_tensor->add_double_val(value); - } - return Status::OK(); - } - return ConvertOpaqueElementsAttr(attr, output_tensor); -} - -// Converts an MLIR elements attribute to a TensorFlow tensor proto -// with the float_val field updated. -Status ConvertFloatElementsAttr(const ElementsAttr attr, - TensorProto* output_tensor) { - if (auto elts = attr.dyn_cast()) { - if (elts.isSplat()) { - output_tensor->add_float_val(elts.getSplatValue()); - } else { - for (auto value : elts.getValues()) - output_tensor->add_float_val(value); - } - return Status::OK(); - } - return ConvertOpaqueElementsAttr(attr, output_tensor); -} - -// Converts an MLIR elements attribute to a TensorFlow tensor proto -// with the half_val field updated. 
-Status ConvertHalfElementsAttr(const ElementsAttr attr, - TensorProto* output_tensor) { - if (auto elts = attr.dyn_cast()) { - if (elts.isSplat()) { - output_tensor->add_half_val( - (*elts.begin()).bitcastToAPInt().getSExtValue()); - } else { - for (const auto& value : elts.getFloatValues()) - output_tensor->add_half_val(value.bitcastToAPInt().getSExtValue()); - } - return Status::OK(); - } - return ConvertOpaqueElementsAttr(attr, output_tensor); -} - -// Converts an MLIR elements attribute to a TensorFlow tensor proto -// with the int_val field updated. -Status ConvertIntElementsAttr(const mlir::ElementsAttr attr, - TensorProto* output_tensor) { - if (auto elts = attr.dyn_cast()) { - if (elts.isSplat()) { - output_tensor->add_int_val((*elts.begin()).getSExtValue()); - } else { - for (const auto& val : elts) - output_tensor->add_int_val(val.getSExtValue()); - } - return Status::OK(); - } - return ConvertOpaqueElementsAttr(attr, output_tensor); -} - -Status ConvertBfloat16ElementsAttr(const mlir::ElementsAttr attr, - TensorProto* output_tensor) { - auto elts = attr.dyn_cast(); - if (!elts) { - return ConvertOpaqueElementsAttr(attr, output_tensor); - } - - // Bfloat16 is internally represented as `double` in MLIR. - if (elts.isSplat()) { - double v = elts.getSplatValue(); - bfloat16 bf16_val = static_cast(v); - output_tensor->add_half_val(absl::bit_cast(bf16_val)); +// Converts an MLIR elements attribute and adds it to specified repeated field. +template +void ConvertElementsAttr(const mlir::DenseElementsAttr attr, + protobuf::RepeatedField* output) { + if (attr.isSplat()) { + output->Add(attr.getSplatValue()); } else { - for (auto v : elts.getValues()) { + for (auto value : attr.getValues()) output->Add(value); + } +} + +// Converts an MLIR elements attribute containing half values and adds it to +// specified repeated field. +void ConvertHalfElementsAttr(const DenseFPElementsAttr attr, + protobuf::RepeatedField* output_tensor) { + if (attr.isSplat()) { + output_tensor->Add((*attr.begin()).bitcastToAPInt().getSExtValue()); + } else { + for (const llvm::APFloat value : attr.getFloatValues()) + output_tensor->Add(value.bitcastToAPInt().getSExtValue()); + } +} + +// Converts an MLIR elements attribute containing int values and adds it to +// specified repeated field. +void ConvertIntElementsAttr(const mlir::DenseIntElementsAttr attr, + protobuf::RepeatedField* output) { + if (attr.isSplat()) { + output->Add((*attr.begin()).getSExtValue()); + } else { + for (const llvm::APInt val : attr) output->Add(val.getSExtValue()); + } +} + +void ConvertBfloat16ElementsAttr(const mlir::DenseFPElementsAttr attr, + protobuf::RepeatedField* output) { + // Bfloat16 is internally represented as `double` in MLIR. + if (attr.isSplat()) { + double v = attr.getSplatValue(); + bfloat16 bf16_val = static_cast(v); + output->Add(absl::bit_cast(bf16_val)); + } else { + for (auto v : attr.getValues()) { bfloat16 bf16_val = static_cast(v); - output_tensor->add_half_val(absl::bit_cast(bf16_val)); + output->Add(absl::bit_cast(bf16_val)); } } - - return Status::OK(); } -// Converts an MLIR elements attribute to a TensorFlow tensor proto -// with the int64_val field updated. 
-Status ConvertInt64ElementsAttr(const mlir::ElementsAttr attr, - TensorProto* output_tensor) { - if (auto elts = attr.dyn_cast()) { - if (elts.isSplat()) { - output_tensor->add_int64_val((*elts.begin()).getSExtValue()); - } else { - for (const auto& val : elts) - output_tensor->add_int64_val(val.getSExtValue()); - } - return Status::OK(); - } - return ConvertOpaqueElementsAttr(attr, output_tensor); -} - -// Converts an MLIR elements attribute to a TensorFlow tensor proto -// with bool_val field updated. -Status ConvertBoolElementsAttr(const mlir::ElementsAttr attr, - TensorProto* output_tensor) { - if (auto elts = attr.dyn_cast()) { - for (const auto& val : elts) { - output_tensor->add_bool_val(val.getBoolValue()); - } - return Status::OK(); - } - return ConvertOpaqueElementsAttr(attr, output_tensor); -} - -Status ConvertToTensorProto(const ElementsAttr attr, - TensorProto* output_tensor) { +Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) { auto type = attr.getType(); auto shape = type.getShape(); DataType output_dtype; TF_RETURN_IF_ERROR(ConvertToDataType(type, &output_dtype)); - output_tensor->set_dtype(output_dtype); - ConvertToTensorShapeProto(shape, output_tensor->mutable_tensor_shape()); + output->set_dtype(output_dtype); + ConvertToTensorShapeProto(shape, output->mutable_tensor_shape()); + + if (attr.isa()) + return ConvertOpaqueElementsAttr(attr.cast(), output); + + auto dense_attr = attr.dyn_cast(); + if (!dense_attr) return errors::InvalidArgument("Unsupported elements attr"); switch (output_dtype) { case DT_FLOAT: - return ConvertFloatElementsAttr(attr, output_tensor); + ConvertElementsAttr(dense_attr, output->mutable_float_val()); + break; case DT_HALF: - // Handles both DenseFPElementsAttr and OpaqueElementsAttr. 
- return ConvertHalfElementsAttr(attr, output_tensor); + ConvertHalfElementsAttr(dense_attr.cast(), + output->mutable_half_val()); + break; case DT_DOUBLE: - return ConvertDoubleElementsAttr(attr, output_tensor); + ConvertElementsAttr(dense_attr, output->mutable_double_val()); + break; case DT_QUINT8: case DT_UINT8: case DT_INT8: @@ -358,20 +325,40 @@ Status ConvertToTensorProto(const ElementsAttr attr, case DT_UINT16: case DT_INT16: case DT_INT32: - return ConvertIntElementsAttr(attr, output_tensor); + ConvertIntElementsAttr(dense_attr.cast(), + output->mutable_int_val()); + break; + case DT_UINT32: + ConvertElementsAttr(dense_attr, output->mutable_uint32_val()); + break; + case DT_UINT64: + ConvertElementsAttr(dense_attr, output->mutable_uint64_val()); + break; case DT_INT64: - return ConvertInt64ElementsAttr(attr, output_tensor); + ConvertElementsAttr(dense_attr, output->mutable_int64_val()); + break; case DT_BOOL: - return ConvertBoolElementsAttr(attr, output_tensor); + ConvertElementsAttr(dense_attr, output->mutable_bool_val()); + break; case DT_BFLOAT16: - return ConvertBfloat16ElementsAttr(attr, output_tensor); + ConvertBfloat16ElementsAttr(dense_attr.cast(), + output->mutable_half_val()); + break; case DT_STRING: - return ConvertStringElementsAttr(attr.cast(), - output_tensor); + ConvertStringElementsAttr(dense_attr.cast(), + output->mutable_string_val()); + break; + case DT_COMPLEX64: + ConvertComplexElementsAttr(dense_attr, output->mutable_scomplex_val()); + break; + case DT_COMPLEX128: + ConvertComplexElementsAttr(dense_attr, output->mutable_dcomplex_val()); + break; default: - return ConvertOpaqueElementsAttr(attr.cast(), - output_tensor); + return errors::Unimplemented(absl::StrCat("Unimplemented data type ", + DataTypeString(output_dtype))); } + return Status::OK(); } Status ConvertToTensor(const mlir::ElementsAttr attr, Tensor* output_tensor) { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc index 673b692b4e6..bf96e3d1df4 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include +#include #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project @@ -23,6 +24,8 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/stream_executor/lib/statusor.h" @@ -79,7 +82,7 @@ TEST(ConvertTypeToTensorTypeTest, ConvertStringTensor) { mlir::Builder b(&context); // Create the sample tensor to convert. 
- tensorflow::Tensor tensor(DT_STRING, TensorShape({1, 2, 2, 1})); + Tensor tensor(DT_STRING, TensorShape({1, 2, 2, 1})); EXPECT_EQ(4, tensor.NumElements()); auto Tt = tensor.flat(); Tt.setValues({"one", "two", "three", "four"}); @@ -97,5 +100,75 @@ TEST(ConvertTypeToTensorTypeTest, ConvertStringTensor) { EXPECT_EQ(string_values[3], mlir::StringRef("four")); } +class ConvertTensorTest : public ::testing::Test { + protected: + template + void VerifyConversion(std::initializer_list values, DataType dtype, + mlir::Type expected_ty) { + mlir::Builder b(expected_ty.getContext()); + Tensor tensor(dtype, TensorShape({static_cast(values.size())})); + tensor.flat().setValues(values); + + auto value_or = ConvertTensor(tensor, &b); + TF_ASSERT_OK(value_or.status()); + auto attr = value_or.ValueOrDie(); + + EXPECT_EQ(attr.getType().getElementType(), expected_ty); + + Tensor out; + TF_ASSERT_OK(ConvertToTensor(attr, &out)); + + test::ExpectTensorEqual(tensor, out); + } +}; + +TEST_F(ConvertTensorTest, Simple) { + RegisterDialects(); + + mlir::MLIRContext context; + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {Eigen::half(1.0)}, DT_HALF, mlir::FloatType::getF16(&context))); + ASSERT_NO_FATAL_FAILURE( + VerifyConversion({bfloat16(1.0), bfloat16(-1.0)}, DT_BFLOAT16, + mlir::FloatType::getBF16(&context))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1.0, -1.0}, DT_FLOAT, mlir::FloatType::getF32(&context))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1.0, -1.0}, DT_DOUBLE, mlir::FloatType::getF64(&context))); + + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, -1}, DT_INT8, mlir::IntegerType::get(8, &context))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, -1}, DT_INT16, mlir::IntegerType::get(16, &context))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, -1}, DT_INT32, mlir::IntegerType::get(32, &context))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, -1}, DT_INT64, mlir::IntegerType::get(64, &context))); + + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, 2}, DT_UINT8, + mlir::IntegerType::get( + 8, mlir::IntegerType::SignednessSemantics::Unsigned, &context))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, 2}, DT_UINT16, + mlir::IntegerType::get( + 16, mlir::IntegerType::SignednessSemantics::Unsigned, &context))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, 2}, DT_UINT32, + mlir::IntegerType::get( + 32, mlir::IntegerType::SignednessSemantics::Unsigned, &context))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {1, 2}, DT_UINT64, + mlir::IntegerType::get( + 64, mlir::IntegerType::SignednessSemantics::Unsigned, &context))); + + ASSERT_NO_FATAL_FAILURE(VerifyConversion>( + {{0.0, 1.0}, {1.0, 0.0}}, DT_COMPLEX64, + mlir::ComplexType::get(mlir::FloatType::getF32(&context)))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion>( + {{0.0, 1.0}, {1.0, 0.0}}, DT_COMPLEX128, + mlir::ComplexType::get(mlir::FloatType::getF64(&context)))); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.cc index ffcd1f71a50..c77107c8de7 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.cc @@ -24,8 +24,8 @@ limitations under the License. 
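// Editorial illustration (not part of the dump_graph.cc hunk that follows): the
// Tensor <-> ElementsAttr round trip exercised by ConvertTensorTest above,
// shown standalone for the newly supported DT_HALF path. A sketch only; the
// caller is assumed to supply an MLIRContext, and the helper name is
// illustrative.
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"

tensorflow::Status RoundTripHalfTensor(mlir::MLIRContext* context) {
  mlir::Builder builder(context);
  tensorflow::Tensor in(tensorflow::DT_HALF, tensorflow::TensorShape({2}));
  in.flat<Eigen::half>().setValues({Eigen::half(1.0), Eigen::half(-1.0)});

  // Tensor -> DenseElementsAttr; DT_HALF now takes the raw-buffer path added in
  // ConvertHalfTensor.
  auto attr_or = tensorflow::ConvertTensor(in, &builder);
  if (!attr_or.ok()) return attr_or.status();

  // ElementsAttr -> TensorProto -> Tensor via the half_val repeated field.
  tensorflow::Tensor out;
  return tensorflow::ConvertToTensor(attr_or.ValueOrDie(), &out);
}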
#include "llvm/ADT/Twine.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Analysis/Verifier.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Verifier.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/platform/env.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc index cc795259893..4877cbc4a44 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc @@ -59,6 +59,18 @@ limitations under the License. namespace tensorflow { namespace { +// static TensorFlow op prefix set. +std::set* GlobalOpPrefixes() { + static std::set* global_op_prefixes = [] { + std::set* result = new std::set; + result->insert("tf."); + result->insert("_tf."); + result->insert("tf_executor."); + return result; + }(); + return global_op_prefixes; +} + // Converts a location to the debug information for the node def. Status ConvertLocation(mlir::Location inst_loc, NodeDef::ExperimentalDebugInfo* debug_info) { @@ -268,8 +280,10 @@ StatusOr GetTensorFlowOpName(llvm::StringRef op_name) { // - ".sink" or ".Sink": only the NextIteration operation has this suffix. We // don't need to consider ".source"/".Source" because the nodes with this // suffix are skipped by the caller and will not be added to the graph. - if (!op_name.consume_front("_tf.") && !op_name.consume_front("tf.") && - !op_name.consume_front("tf_executor.")) { + auto prefixes = GlobalOpPrefixes(); + if (std::none_of(prefixes->begin(), prefixes->end(), [&](std::string prefix) { + return op_name.consume_front(prefix); + })) { return errors::FailedPrecondition("op node '", op_name.str(), "' was not a TF op!"); } @@ -506,4 +520,9 @@ bool IsLegacyCallInstruction(mlir::Operation* inst) { inst->getName().getStringRef().compare("_tf.LegacyCall") == 0; } +Status AddTensorFlowOpPrefix(std::string prefix) { + GlobalOpPrefixes()->insert(prefix); + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h index 32ed528bd0d..58fe39fa4e8 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h @@ -34,10 +34,17 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/stream_executor/lib/statusor.h" +namespace mlir { +class ShapedType; +} // namespace mlir + namespace tensorflow { using stream_executor::port::StatusOr; +// Add custom op prefix for TensorFlow dialects. +Status AddTensorFlowOpPrefix(std::string); + // Maps an MLIR op name in the TensorFlow dialect or the TensorFlow control // dialect back into a TensorFlow valid op name. 
StatusOr GetTensorFlowOpName(llvm::StringRef); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc index 47c5d27767d..3d16352f78e 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc @@ -31,12 +31,17 @@ inline llvm::StringRef StringViewToRef(absl::string_view view) { } } // namespace -Status LoadProtoFromBuffer(absl::string_view input, - protobuf::MessageLite* proto) { +Status LoadProtoFromBuffer(absl::string_view input, protobuf::Message* proto) { // Attempt to parse as text. if (ParseTextProto(input, "", proto).ok()) return Status::OK(); // Else attempt to parse as binary. + return LoadProtoFromBuffer(input, static_cast(proto)); +} + +Status LoadProtoFromBuffer(absl::string_view input, + protobuf::MessageLite* proto) { + // Attempt to parse as binary. protobuf::io::ArrayInputStream binary_stream(input.data(), input.size()); if (proto->ParseFromZeroCopyStream(&binary_stream)) return Status::OK(); @@ -44,8 +49,8 @@ Status LoadProtoFromBuffer(absl::string_view input, return errors::InvalidArgument("Could not parse input proto"); } -Status LoadProtoFromFile(absl::string_view input_filename, - protobuf::MessageLite* proto) { +template +Status LoadProtoFromFileImpl(absl::string_view input_filename, T* proto) { const auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(StringViewToRef(input_filename)); if (std::error_code error = file_or_err.getError()) { @@ -60,4 +65,14 @@ Status LoadProtoFromFile(absl::string_view input_filename, return LoadProtoFromBuffer(content, proto); } +Status LoadProtoFromFile(absl::string_view input_filename, + protobuf::Message* proto) { + return LoadProtoFromFileImpl(input_filename, proto); +} + +Status LoadProtoFromFile(absl::string_view input_filename, + protobuf::MessageLite* proto) { + return LoadProtoFromFileImpl(input_filename, proto); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/import_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/import_utils.h index 56cd188f393..ad1531dd449 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/import_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/import_utils.h @@ -24,13 +24,20 @@ namespace tensorflow { // Reads text (.pbtext) or binary (.pb) format of a proto message from the given // buffer. Returns error status of the file is not found or malformed proto. +// Note that text protos can only be parsed when full protobuf::Message protos +// are used, and will fail for protobuf::MessageLite protos. +Status LoadProtoFromBuffer(absl::string_view input, protobuf::Message* proto); Status LoadProtoFromBuffer(absl::string_view input, - tensorflow::protobuf::MessageLite* proto); + protobuf::MessageLite* proto); // Reads text (.pbtext) or binary (.pb) format of a proto message from the given // file path. Returns error status of the file is not found or malformed proto. +// Note that text protos can only be parsed when full protobuf::Message protos +// are used, and will fail for protobuf::MessageLite protos. 
Status LoadProtoFromFile(absl::string_view input_filename, - tensorflow::protobuf::MessageLite* proto); + protobuf::Message* proto); +Status LoadProtoFromFile(absl::string_view input_filename, + protobuf::MessageLite* proto); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.cc b/tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.cc index b616d34fdd8..1bf615de8c4 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.cc @@ -24,7 +24,6 @@ limitations under the License. namespace tensorflow { -#ifndef TENSORFLOW_LITE_PROTOS namespace { // Error collector that simply ignores errors reported. class NoOpErrorCollector : public protobuf::io::ErrorCollector { @@ -32,7 +31,6 @@ class NoOpErrorCollector : public protobuf::io::ErrorCollector { void AddError(int line, int column, const std::string& message) override {} }; } // namespace -#endif // TENSORFLOW_LITE_PROTOS Status ConsumePrefix(absl::string_view str, absl::string_view prefix, absl::string_view* output) { @@ -45,8 +43,7 @@ Status ConsumePrefix(absl::string_view str, absl::string_view prefix, Status ParseTextProto(absl::string_view text_proto, absl::string_view prefix_to_strip, - protobuf::MessageLite* parsed_proto) { -#ifndef TENSORFLOW_LITE_PROTOS + protobuf::Message* parsed_proto) { protobuf::TextFormat::Parser parser; // Don't produce errors when attempting to parse text format as it would fail // when the input is actually a binary file. @@ -60,15 +57,11 @@ Status ParseTextProto(absl::string_view text_proto, } protobuf::io::ArrayInputStream input_stream(text_proto_without_prefix.data(), text_proto_without_prefix.size()); - if (parser.Parse(&input_stream, - tensorflow::down_cast(parsed_proto))) { + if (parser.Parse(&input_stream, parsed_proto)) { return Status::OK(); } parsed_proto->Clear(); return errors::InvalidArgument("Could not parse text proto: ", text_proto); -#else - return errors::Unavailable("Cannot parse text protos on mobile."); -#endif // TENSORFLOW_LITE_PROTOS } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.h b/tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.h index 5646f1378af..c1f1e3b111d 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.h @@ -32,7 +32,12 @@ Status ConsumePrefix(absl::string_view str, absl::string_view prefix, // proto. Status ParseTextProto(absl::string_view text_proto, absl::string_view prefix_to_strip, - protobuf::MessageLite* parsed_proto); + protobuf::Message* parsed_proto); +inline Status ParseTextProto(absl::string_view /* text_proto */, + absl::string_view /* prefix_to_strip */, + protobuf::MessageLite* /* parsed_proto */) { + return errors::Unavailable("Cannot parse text protos on mobile."); +} } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc index 6cf2781e48d..06c10c26835 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc @@ -164,12 +164,19 @@ std::string GetTPUCompilationDevice(Device system_device) { return DeviceNameUtils::ParsedNameToString(system_device); } +// Finds the host CPU device for a given TPU device. 
+std::string GetCPUHostDeviceForTPUDevice(Device tpu_device) { + tpu_device.type = DEVICE_CPU; + tpu_device.id = 0; + return DeviceNameUtils::ParsedNameToString(tpu_device); +} + // Determines execution devices when topology and device assignment are not // defined. This is a special case where a single core computation is replicated // to every core in the mesh. TPU devices are simply added to // `execution_devices` of one replica. `num_replicas` must be 1 or the total // number of TPU devices available, and `num_cores_per_replica` must be 1. -StatusOr GetFullMeshTPUExecutionDeviceAssignment( +StatusOr GetFullMeshTPUExecutionDeviceAssignment( int num_replicas, int num_cores_per_replica, llvm::ArrayRef> tpu_devices) { const int num_tasks = tpu_devices.size(); @@ -185,17 +192,18 @@ StatusOr GetFullMeshTPUExecutionDeviceAssignment( "'num_cores_per_replica' must be equal to 1, got ", num_cores_per_replica); - ExecutionDevices execution_devices; - execution_devices.reserve(num_replicas); + TPUDevicesAndHosts devices_and_hosts; + devices_and_hosts.reserve(num_replicas); for (int i = 0; i < num_replicas; ++i) { const int task = i / num_tpus_per_task; const int device = i % num_tpus_per_task; - execution_devices.push_back( - {tensorflow::DeviceNameUtils::ParsedNameToString( - tpu_devices[task][device])}); + const auto& tpu_device = tpu_devices[task][device]; + devices_and_hosts.push_back({TPUDeviceAndHost( + /*device=*/tensorflow::DeviceNameUtils::ParsedNameToString(tpu_device), + /*host=*/GetCPUHostDeviceForTPUDevice(tpu_device))}); } - return execution_devices; + return devices_and_hosts; } // Helper struct for keeping track of task and device for an associated TPU @@ -326,7 +334,7 @@ StatusOr> ParseTopologyAttr( // - number of device coordinates (in tuple 3) match number 'num_replicas' * // 'num_cores_per_replica' // - a TPU device associated with each device coordinate -StatusOr> +StatusOr> GetGeneralTPUExecutionDeviceAssignment( int num_replicas, int num_cores_per_replica, llvm::ArrayRef> tpu_devices, @@ -361,9 +369,9 @@ GetGeneralTPUExecutionDeviceAssignment( std::vector used_device_ids( location_to_id(bound_x - 1, bound_y - 1, bound_z - 1, bound_core - 1), false); - ExecutionDevices execution_devices( - num_replicas, - llvm::SmallVector(num_cores_per_replica, "")); + TPUDevicesAndHosts devices_and_hosts( + num_replicas, llvm::SmallVector( + num_cores_per_replica, TPUDeviceAndHost())); xla::DeviceAssignment device_assignment(num_replicas, num_cores_per_replica); int pos = 0; for (int replica = 0; replica < num_replicas; ++replica) { @@ -393,16 +401,18 @@ GetGeneralTPUExecutionDeviceAssignment( used_device_ids[device_id] = true; device_assignment(replica, logical_core) = device_id; - execution_devices[replica][logical_core] = - DeviceNameUtils::ParsedNameToString(tpu_devices[task][device]); + auto& device_and_host = devices_and_hosts[replica][logical_core]; + const auto& tpu_device = tpu_devices[task][device]; + device_and_host.device = DeviceNameUtils::ParsedNameToString(tpu_device); + device_and_host.host = GetCPUHostDeviceForTPUDevice(tpu_device); } } xla::DeviceAssignmentProto device_assignment_proto; TF_RETURN_IF_ERROR(device_assignment.Serialize(&device_assignment_proto)); - return std::pair( - std::move(execution_devices), std::move(device_assignment_proto)); + return std::pair( + std::move(devices_and_hosts), std::move(device_assignment_proto)); } } // anonymous namespace diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h 
b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h index dd296a13f4b..5fdb6b8768b 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h @@ -30,29 +30,40 @@ limitations under the License. namespace tensorflow { using stream_executor::port::StatusOr; -// TPU devices to be used for execution (e.g. devices for TPUExecute ops). They -// are ordered by `num_replicas` followed by `num_cores_per_replica`. -using ExecutionDevices = - llvm::SmallVector, 8>; +// A TPU device for execution alongside its associated host CPU device. +struct TPUDeviceAndHost { + TPUDeviceAndHost() {} + TPUDeviceAndHost(llvm::StringRef device, llvm::StringRef host) + : device(device), host(host) {} -// TPU compilation device, execution devices, and optionally execution device -// IDs. Execution device IDs are populated if `topology` and `device_assignment` -// are provided. + std::string device; + std::string host; +}; + +// TPU devices to be used for execution (e.g. devices for TPUExecute ops) and +// their associated host CPU devices (for outside compilation). They are ordered +// by `num_replicas` followed by `num_cores_per_replica`. +using TPUDevicesAndHosts = + llvm::SmallVector, 8>; + +// TPU compilation device, execution and associated host devices, and optionally +// execution device IDs. Execution device IDs are populated if `topology` and +// `device_assignment` are provided. struct TPUDeviceAssignment { TPUDeviceAssignment(llvm::StringRef compilation_device, - ExecutionDevices&& execution_devices) + TPUDevicesAndHosts&& tpu_devices) : compilation_device(compilation_device), - execution_devices(std::move(execution_devices)) {} + tpu_devices(std::move(tpu_devices)) {} TPUDeviceAssignment(llvm::StringRef compilation_device, - ExecutionDevices&& execution_devices, + TPUDevicesAndHosts&& tpu_devices, xla::DeviceAssignmentProto&& xla_device_assignment) : compilation_device(compilation_device), - execution_devices(std::move(execution_devices)), + tpu_devices(std::move(tpu_devices)), xla_device_assignment(std::move(xla_device_assignment)) {} std::string compilation_device; - ExecutionDevices execution_devices; + TPUDevicesAndHosts tpu_devices; llvm::Optional xla_device_assignment; }; diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc index 87319f2adeb..7ac5635a6e4 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc @@ -323,30 +323,46 @@ TEST(TPURewriteDeviceUtilTest, ValidFullMeshDeviceAssignment) { TF_ASSERT_OK(status_or.status()); - auto& tpu_device_assignment = status_or.ValueOrDie(); + const auto& tpu_device_assignment = status_or.ValueOrDie(); EXPECT_EQ(tpu_device_assignment.compilation_device, "/job:worker/replica:0/task:0/device:CPU:0"); - auto& execution_devices = tpu_device_assignment.execution_devices; - ASSERT_EQ(execution_devices.size(), 8); - for (const auto& replica_execution_device : execution_devices) - ASSERT_EQ(replica_execution_device.size(), 1); + const auto& tpu_devices = tpu_device_assignment.tpu_devices; + ASSERT_EQ(tpu_devices.size(), 8); + for (const auto& replica_tpu_devices : tpu_devices) + ASSERT_EQ(replica_tpu_devices.size(), 1); - EXPECT_EQ(execution_devices[0][0], + EXPECT_EQ(tpu_devices[0][0].device, "/job:worker/replica:0/task:0/device:TPU:0"); - 
EXPECT_EQ(execution_devices[1][0], + EXPECT_EQ(tpu_devices[0][0].host, + "/job:worker/replica:0/task:0/device:CPU:0"); + EXPECT_EQ(tpu_devices[1][0].device, "/job:worker/replica:0/task:0/device:TPU:1"); - EXPECT_EQ(execution_devices[2][0], + EXPECT_EQ(tpu_devices[1][0].host, + "/job:worker/replica:0/task:0/device:CPU:0"); + EXPECT_EQ(tpu_devices[2][0].device, "/job:worker/replica:0/task:0/device:TPU:2"); - EXPECT_EQ(execution_devices[3][0], + EXPECT_EQ(tpu_devices[2][0].host, + "/job:worker/replica:0/task:0/device:CPU:0"); + EXPECT_EQ(tpu_devices[3][0].device, "/job:worker/replica:0/task:0/device:TPU:3"); - EXPECT_EQ(execution_devices[4][0], + EXPECT_EQ(tpu_devices[3][0].host, + "/job:worker/replica:0/task:0/device:CPU:0"); + EXPECT_EQ(tpu_devices[4][0].device, "/job:worker/replica:0/task:1/device:TPU:0"); - EXPECT_EQ(execution_devices[5][0], + EXPECT_EQ(tpu_devices[4][0].host, + "/job:worker/replica:0/task:1/device:CPU:0"); + EXPECT_EQ(tpu_devices[5][0].device, "/job:worker/replica:0/task:1/device:TPU:1"); - EXPECT_EQ(execution_devices[6][0], + EXPECT_EQ(tpu_devices[5][0].host, + "/job:worker/replica:0/task:1/device:CPU:0"); + EXPECT_EQ(tpu_devices[6][0].device, "/job:worker/replica:0/task:1/device:TPU:2"); - EXPECT_EQ(execution_devices[7][0], + EXPECT_EQ(tpu_devices[6][0].host, + "/job:worker/replica:0/task:1/device:CPU:0"); + EXPECT_EQ(tpu_devices[7][0].device, "/job:worker/replica:0/task:1/device:TPU:3"); + EXPECT_EQ(tpu_devices[7][0].host, + "/job:worker/replica:0/task:1/device:CPU:0"); EXPECT_FALSE(tpu_device_assignment.xla_device_assignment.hasValue()); } @@ -410,30 +426,46 @@ TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh2x2x2) { TF_ASSERT_OK(status_or.status()); - auto& tpu_device_assignment = status_or.ValueOrDie(); + const auto& tpu_device_assignment = status_or.ValueOrDie(); EXPECT_EQ(tpu_device_assignment.compilation_device, "/job:worker/replica:0/task:0/device:CPU:0"); - auto& execution_devices = tpu_device_assignment.execution_devices; - ASSERT_EQ(execution_devices.size(), 4); - for (const auto& replica_execution_device : execution_devices) - ASSERT_EQ(replica_execution_device.size(), 2); + const auto& tpu_devices = tpu_device_assignment.tpu_devices; + ASSERT_EQ(tpu_devices.size(), 4); + for (const auto& replica_tpu_devices : tpu_devices) + ASSERT_EQ(replica_tpu_devices.size(), 2); - EXPECT_EQ(execution_devices[0][0], + EXPECT_EQ(tpu_devices[0][0].device, "/job:worker/replica:0/task:0/device:TPU:0"); - EXPECT_EQ(execution_devices[0][1], + EXPECT_EQ(tpu_devices[0][0].host, + "/job:worker/replica:0/task:0/device:CPU:0"); + EXPECT_EQ(tpu_devices[0][1].device, "/job:worker/replica:0/task:1/device:TPU:3"); - EXPECT_EQ(execution_devices[1][0], + EXPECT_EQ(tpu_devices[0][1].host, + "/job:worker/replica:0/task:1/device:CPU:0"); + EXPECT_EQ(tpu_devices[1][0].device, "/job:worker/replica:0/task:0/device:TPU:1"); - EXPECT_EQ(execution_devices[1][1], + EXPECT_EQ(tpu_devices[1][0].host, + "/job:worker/replica:0/task:0/device:CPU:0"); + EXPECT_EQ(tpu_devices[1][1].device, "/job:worker/replica:0/task:1/device:TPU:2"); - EXPECT_EQ(execution_devices[2][0], + EXPECT_EQ(tpu_devices[1][1].host, + "/job:worker/replica:0/task:1/device:CPU:0"); + EXPECT_EQ(tpu_devices[2][0].device, "/job:worker/replica:0/task:0/device:TPU:3"); - EXPECT_EQ(execution_devices[2][1], + EXPECT_EQ(tpu_devices[2][0].host, + "/job:worker/replica:0/task:0/device:CPU:0"); + EXPECT_EQ(tpu_devices[2][1].device, "/job:worker/replica:0/task:1/device:TPU:0"); - EXPECT_EQ(execution_devices[3][0], + 
EXPECT_EQ(tpu_devices[2][1].host, + "/job:worker/replica:0/task:1/device:CPU:0"); + EXPECT_EQ(tpu_devices[3][0].device, "/job:worker/replica:0/task:0/device:TPU:2"); - EXPECT_EQ(execution_devices[3][1], + EXPECT_EQ(tpu_devices[3][0].host, + "/job:worker/replica:0/task:0/device:CPU:0"); + EXPECT_EQ(tpu_devices[3][1].device, "/job:worker/replica:0/task:1/device:TPU:1"); + EXPECT_EQ(tpu_devices[3][1].host, + "/job:worker/replica:0/task:1/device:CPU:0"); auto& xla_device_assignment = tpu_device_assignment.xla_device_assignment; ASSERT_TRUE(xla_device_assignment.hasValue()); @@ -511,23 +543,35 @@ TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh1x2x1x3) { EXPECT_EQ(tpu_device_assignment.compilation_device, "/job:worker/replica:0/task:0/device:CPU:0"); - auto& execution_devices = tpu_device_assignment.execution_devices; - ASSERT_EQ(execution_devices.size(), 2); - for (const auto& replica_execution_device : execution_devices) - ASSERT_EQ(replica_execution_device.size(), 3); + auto& tpu_devices = tpu_device_assignment.tpu_devices; + ASSERT_EQ(tpu_devices.size(), 2); + for (const auto& replica_tpu_devices : tpu_devices) + ASSERT_EQ(replica_tpu_devices.size(), 3); - EXPECT_EQ(execution_devices[0][0], + EXPECT_EQ(tpu_devices[0][0].device, "/job:worker/replica:0/task:1/device:TPU:1"); - EXPECT_EQ(execution_devices[0][1], + EXPECT_EQ(tpu_devices[0][0].host, + "/job:worker/replica:0/task:1/device:CPU:0"); + EXPECT_EQ(tpu_devices[0][1].device, "/job:worker/replica:0/task:1/device:TPU:0"); - EXPECT_EQ(execution_devices[0][2], + EXPECT_EQ(tpu_devices[0][1].host, + "/job:worker/replica:0/task:1/device:CPU:0"); + EXPECT_EQ(tpu_devices[0][2].device, "/job:worker/replica:0/task:2/device:TPU:0"); - EXPECT_EQ(execution_devices[1][0], + EXPECT_EQ(tpu_devices[0][2].host, + "/job:worker/replica:0/task:2/device:CPU:0"); + EXPECT_EQ(tpu_devices[1][0].device, "/job:worker/replica:0/task:2/device:TPU:1"); - EXPECT_EQ(execution_devices[1][1], + EXPECT_EQ(tpu_devices[1][0].host, + "/job:worker/replica:0/task:2/device:CPU:0"); + EXPECT_EQ(tpu_devices[1][1].device, "/job:worker/replica:0/task:0/device:TPU:0"); - EXPECT_EQ(execution_devices[1][2], + EXPECT_EQ(tpu_devices[1][1].host, + "/job:worker/replica:0/task:0/device:CPU:0"); + EXPECT_EQ(tpu_devices[1][2].device, "/job:worker/replica:0/task:0/device:TPU:1"); + EXPECT_EQ(tpu_devices[1][2].host, + "/job:worker/replica:0/task:0/device:CPU:0"); auto& xla_device_assignment = tpu_device_assignment.xla_device_assignment; ASSERT_TRUE(xla_device_assignment.hasValue()); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc index aef336330e0..083a5abf840 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc @@ -202,23 +202,23 @@ mlir::LogicalResult HandleTileShardedInputs( } // namespace mlir::LogicalResult ExtractInputsForLogicalDevices( - const int num_cores_per_replica, mlir::tf_device::LaunchFuncOp launch_func, - mlir::OpBuilder* builder, + const int num_cores_per_replica, + mlir::tf_device::ClusterFuncOp cluster_func, mlir::OpBuilder* builder, llvm::SmallVectorImpl>* input_list) { // Initialize the input list for each logical devices. 
input_list->reserve(num_cores_per_replica); for (int i = 0; i < num_cores_per_replica; ++i) input_list->emplace_back(llvm::SmallVector()); - llvm::SmallVector launch_func_inputs( - launch_func.getOperands()); + llvm::SmallVector cluster_func_inputs( + cluster_func.getOperands()); auto sharding_attrs = - launch_func.getOperation()->getAttrOfType( + cluster_func.getOperation()->getAttrOfType( kInputShardingAttr); // If sharding attribute does not exist, then all inputs are placed on 0th // logical core by default. if (!sharding_attrs) { - (*input_list)[0] = launch_func_inputs; + (*input_list)[0] = cluster_func_inputs; return mlir::success(); } @@ -229,7 +229,7 @@ mlir::LogicalResult ExtractInputsForLogicalDevices( for (const auto& sharding_attr_and_index : llvm::enumerate(sharding_attrs)) { const auto& sharding_attr = sharding_attr_and_index.value(); const auto input_index = sharding_attr_and_index.index(); - const auto& input_value = launch_func_inputs[input_index]; + const auto& input_value = cluster_func_inputs[input_index]; xla::OpSharding sharding; sharding.ParseFromString( @@ -239,11 +239,11 @@ mlir::LogicalResult ExtractInputsForLogicalDevices( if (input_sharding_type == xla::OpSharding::OTHER) { llvm::SmallVector tiled_inputs; auto result = HandleTileShardedInputs( - launch_func.getLoc(), sharding, input_value, builder, &tiled_inputs); + cluster_func.getLoc(), sharding, input_value, builder, &tiled_inputs); if (mlir::failed(result)) return mlir::failure(); if (tiled_inputs.size() != num_cores_per_replica) - launch_func.emitError(llvm::formatv( + cluster_func.emitError(llvm::formatv( "incorrect {0}-th tiled input sharding received. " "Product of tile sharding splits({1}) must be equal to " "number of logical devices : {2}", @@ -265,36 +265,37 @@ mlir::LogicalResult ExtractInputsForLogicalDevices( } mlir::LogicalResult ParseAndValidateOutputSharding( - const int num_cores_per_replica, mlir::tf_device::LaunchFuncOp launch_func, + const int num_cores_per_replica, + mlir::tf_device::ClusterFuncOp cluster_func, mlir::SmallVector* output_sharding_list) { - output_sharding_list->reserve(launch_func.getNumResults()); + output_sharding_list->reserve(cluster_func.getNumResults()); const auto output_sharding_attrs = - launch_func.getOperation()->getAttrOfType( + cluster_func.getOperation()->getAttrOfType( kOutputShardingAttr); if (!output_sharding_attrs) - return launch_func.emitError( - "output_sharding_configuration missing from launch func"); + return cluster_func.emitError( + "output_sharding_configuration missing from cluster func"); - if (output_sharding_attrs.size() != launch_func.getNumResults()) - return launch_func.emitError("incorrect number of output sharding"); + if (output_sharding_attrs.size() != cluster_func.getNumResults()) + return cluster_func.emitError("incorrect number of output sharding"); for (auto output_sharding_and_index : llvm::enumerate(output_sharding_attrs)) { const auto& output_sharding = output_sharding_and_index.value(); const int sharding_index = output_sharding_and_index.index(); if (!output_sharding.isa()) - return launch_func.emitError(llvm::formatv( + return cluster_func.emitError(llvm::formatv( "non-string output sharding at index {0}", sharding_index)); xla::OpSharding sharding; if (!sharding.ParseFromString( output_sharding.cast().getValue().str())) - return launch_func.emitError("incorrect sharding format for outputs"); + return cluster_func.emitError("incorrect sharding format for outputs"); if (sharding.type() == xla::OpSharding::OTHER && 
sharding.tile_assignment_devices_size() != num_cores_per_replica) - return launch_func.emitError(llvm::formatv( + return cluster_func.emitError(llvm::formatv( "incorrect sharding format for outputs. Number of " "tiled outputs({0}) must match the number of logical " "devices({1})", @@ -303,7 +304,7 @@ mlir::LogicalResult ParseAndValidateOutputSharding( if (sharding.type() == xla::OpSharding::MAXIMAL && ((sharding.tile_assignment_devices(0) >= num_cores_per_replica) || (sharding.tile_assignment_devices(0) < 0))) - return launch_func.emitError(llvm::formatv( + return cluster_func.emitError(llvm::formatv( "incorrect sharding format for outputs. Maximal " "sharding should be assigned to device id in range " "[0, {0}). Currently assigned to {1}", @@ -323,15 +324,15 @@ bool IsAssignedToLogicalDevice(const int core_id, } // Returns the index of the return value of region in -// `tf_device.parallel_execute` that represents launch func output at -// index |launch_func_output_index|. Regions of parallel_execute may +// `tf_device.parallel_execute` that represents cluster func output at +// index |cluster_func_output_index|. Regions of parallel_execute may // have different return values depending on outside sharding // configuration. -int MapLaunchOutputIndexWithRegionOutputIndex( +int MapClusterOutputIndexWithRegionOutputIndex( llvm::ArrayRef output_sharding_config, const int core_id, - const int launch_func_output_index) { + const int cluster_func_output_index) { int region_output_index = 0; - for (int output_index = 0; output_index < launch_func_output_index; + for (int output_index = 0; output_index < cluster_func_output_index; ++output_index) { const auto& sharding = output_sharding_config[output_index]; if (sharding.type() != xla::OpSharding::MAXIMAL || @@ -344,8 +345,8 @@ int MapLaunchOutputIndexWithRegionOutputIndex( // Merges outputs from TPU computation for tile-sharded outputs. 
mlir::LogicalResult HandleTileShardedOutputs( - const int launch_func_output_index, const xla::OpSharding& sharding, - const mlir::Location& location, mlir::Value launch_func_output, + const int cluster_func_output_index, const xla::OpSharding& sharding, + const mlir::Location& location, mlir::Value cluster_func_output, mlir::tf_device::ParallelExecuteOp parallel_execute, mlir::OpBuilder* builder) { // Inject concat ops after parallel_execute to merge outputs from @@ -357,8 +358,8 @@ mlir::LogicalResult HandleTileShardedOutputs( llvm::SmallVector outputs_to_merge; outputs_to_merge.reserve(sharding.tile_assignment_devices_size()); for (const auto logical_device_id : sharding.tile_assignment_devices()) { - const int region_output_index = MapLaunchOutputIndexWithRegionOutputIndex( - sharding, logical_device_id, launch_func_output_index); + const int region_output_index = MapClusterOutputIndexWithRegionOutputIndex( + sharding, logical_device_id, cluster_func_output_index); const auto output_from_logical_device = parallel_execute.GetRegionOutputs( logical_device_id)[region_output_index]; outputs_to_merge.emplace_back(output_from_logical_device); @@ -393,30 +394,30 @@ mlir::LogicalResult HandleTileShardedOutputs( } assert(outputs_to_merge.size() == 1); - launch_func_output.replaceAllUsesWith(outputs_to_merge[0]); + cluster_func_output.replaceAllUsesWith(outputs_to_merge[0]); return mlir::success(); } mlir::LogicalResult ValidateAndGetTiledExecuteOutputShape( const mlir::Location& location, - const mlir::TensorType launch_func_output_type, + const mlir::TensorType cluster_func_output_type, const xla::OpSharding& output_sharding, mlir::Type* tiled_logical_computation_type) { auto new_output_shape = - llvm::to_vector<4>(launch_func_output_type.getShape()); + llvm::to_vector<4>(cluster_func_output_type.getShape()); for (auto dimension_and_output_splits : llvm::enumerate(output_sharding.tile_assignment_dimensions())) { const auto dimension_index = dimension_and_output_splits.index(); const auto output_splits = dimension_and_output_splits.value(); - const auto& output_shape = launch_func_output_type.getShape(); + const auto output_shape = cluster_func_output_type.getShape(); if (output_shape[dimension_index] == mlir::ShapedType::kDynamicSize) { - *tiled_logical_computation_type = launch_func_output_type; + *tiled_logical_computation_type = cluster_func_output_type; break; } auto output_shape_at_dim = - launch_func_output_type.getShape()[dimension_index]; + cluster_func_output_type.getShape()[dimension_index]; if (output_shape_at_dim % output_splits != 0) { mlir::emitError( location, @@ -432,7 +433,7 @@ mlir::LogicalResult ValidateAndGetTiledExecuteOutputShape( } *tiled_logical_computation_type = mlir::RankedTensorType::get( - new_output_shape, launch_func_output_type.getElementType()); + new_output_shape, cluster_func_output_type.getElementType()); return mlir::success(); } @@ -441,34 +442,34 @@ mlir::LogicalResult ValidateAndGetTiledExecuteOutputShape( mlir::LogicalResult GetOutputTypesForLogicalDeviceComputation( const int core_id, llvm::ArrayRef output_sharding_config, - mlir::tf_device::LaunchFuncOp launch_func, + mlir::tf_device::ClusterFuncOp cluster_func, llvm::SmallVectorImpl* output_types) { - output_types->reserve(launch_func.getNumResults()); + output_types->reserve(cluster_func.getNumResults()); - for (auto result_and_index : llvm::enumerate(launch_func.getResults())) { + for (auto result_and_index : llvm::enumerate(cluster_func.getResults())) { const auto output_index = 
result_and_index.index(); const auto& output_sharding = output_sharding_config[output_index]; const auto output_sharding_type = output_sharding.type(); - const auto& launch_func_output_type = + const auto cluster_func_output_type = result_and_index.value().getType().cast(); - // If output shape of launch func is statically known and output is tiled - // sharded, then the corresponding output shape of launch func must be + // If output shape of cluster func is statically known and output is tiled + // sharded, then the corresponding output shape of cluster func must be // evenly divisible number of shardings. if (output_sharding_type == xla::OpSharding::OTHER) { mlir::Type tiled_logical_computation_type; - if (launch_func_output_type.hasRank()) { + if (cluster_func_output_type.hasRank()) { auto result = ValidateAndGetTiledExecuteOutputShape( - launch_func.getLoc(), launch_func_output_type, output_sharding, + cluster_func.getLoc(), cluster_func_output_type, output_sharding, &tiled_logical_computation_type); if (mlir::failed(result)) return mlir::failure(); } else { - tiled_logical_computation_type = launch_func_output_type; + tiled_logical_computation_type = cluster_func_output_type; } output_types->emplace_back(tiled_logical_computation_type); } else if (output_sharding_type == xla::OpSharding::REPLICATED || IsAssignedToLogicalDevice(core_id, output_sharding)) { - output_types->emplace_back(launch_func_output_type); + output_types->emplace_back(cluster_func_output_type); } } @@ -478,17 +479,17 @@ mlir::LogicalResult GetOutputTypesForLogicalDeviceComputation( void RemapOutputsFromLogicalDevices( const mlir::Location& location, llvm::ArrayRef output_sharding_config, - mlir::tf_device::LaunchFuncOp launch_func, + mlir::tf_device::ClusterFuncOp cluster_func, mlir::tf_device::ParallelExecuteOp parallel_execute, mlir::OpBuilder* builder) { - for (auto result_and_index : llvm::enumerate(launch_func.getResults())) { + for (auto result_and_index : llvm::enumerate(cluster_func.getResults())) { const auto output_index = result_and_index.index(); - const auto& launch_func_output = result_and_index.value(); + const auto cluster_func_output = result_and_index.value(); const auto& output_sharding = output_sharding_config[output_index]; const auto output_sharding_type = output_sharding.type(); if (output_sharding_type == xla::OpSharding::OTHER) { HandleTileShardedOutputs(output_index, output_sharding, location, - launch_func_output, parallel_execute, builder); + cluster_func_output, parallel_execute, builder); continue; } @@ -497,13 +498,13 @@ void RemapOutputsFromLogicalDevices( logical_device_id = output_sharding.tile_assignment_devices(0); // For maximal sharding configuration, correctly remap outputs from - // parallel_execute region to users of the launch func. - const int region_output_index = MapLaunchOutputIndexWithRegionOutputIndex( + // parallel_execute region to users of the cluster func. 
+ const int region_output_index = MapClusterOutputIndexWithRegionOutputIndex( output_sharding_config, logical_device_id, output_index); const auto output_from_logical_device = parallel_execute.GetRegionOutputs( logical_device_id)[region_output_index]; - launch_func_output.replaceAllUsesWith(output_from_logical_device); + cluster_func_output.replaceAllUsesWith(output_from_logical_device); } } @@ -522,7 +523,7 @@ llvm::SmallVector, 4> GetMetadataArgumentMapping( const auto& sharding = arg_and_idx.value().sharding(); const int64_t idx = arg_and_idx.index(); - const auto& sharding_type = sharding.type(); + const auto sharding_type = sharding.type(); if (sharding_type == xla::OpSharding::OTHER) { for (const auto& device : sharding.tile_assignment_devices()) input_mappings[device].push_back(idx); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h index 52a633d3111..69bc092927d 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h @@ -32,19 +32,20 @@ namespace tensorflow { extern const char* const kInputShardingAttr; extern const char* const kOutputShardingAttr; -// Parses "input_sharding_configuration" attribute and returns a list where -// i-th element is a list of mlir::Value's which represent inputs for the -// TPU computation correponding to i-th logical device. If the attribute -// does not exist, the all inputs are placed on logical core 0. +// Parses "input_sharding_configuration" attribute and returns a list where i-th +// element is a list of mlir::Value's which represent inputs for the TPU +// computation correponding to i-th logical device. If the attribute does not +// exist, the all inputs are placed on logical core 0. mlir::LogicalResult ExtractInputsForLogicalDevices( - const int num_cores_per_replica, mlir::tf_device::LaunchFuncOp launch_func, - mlir::OpBuilder* builder, + const int num_cores_per_replica, + mlir::tf_device::ClusterFuncOp cluster_func, mlir::OpBuilder* builder, llvm::SmallVectorImpl>* input_list); -// Extracts a list of OpSharding that represent output sharding configuration -// of `tf_device.launch`. +// Extracts a list of OpSharding that represent output sharding configuration of +// `tf_device.cluster`. mlir::LogicalResult ParseAndValidateOutputSharding( - const int num_cores_per_replica, mlir::tf_device::LaunchFuncOp launch_func, + const int num_cores_per_replica, + mlir::tf_device::ClusterFuncOp cluster_func, mlir::SmallVector* output_sharding_list); // Retrieves output types for TPUExecute op representing execution for provided @@ -52,15 +53,15 @@ mlir::LogicalResult ParseAndValidateOutputSharding( // different outputs depending on the output sharding configuration. mlir::LogicalResult GetOutputTypesForLogicalDeviceComputation( const int core_id, llvm::ArrayRef output_sharding_config, - mlir::tf_device::LaunchFuncOp launch_func, + mlir::tf_device::ClusterFuncOp cluster_func, llvm::SmallVectorImpl* output_types); // Remaps outputs of `tf_device.parallel_execute` op that represent concurrent -// execution of the `tf_device.launch_func` with its users. +// execution of the `tf_device.cluster_func` with its users. 
void RemapOutputsFromLogicalDevices( const mlir::Location& location, llvm::ArrayRef output_sharding_config, - mlir::tf_device::LaunchFuncOp launch_func, + mlir::tf_device::ClusterFuncOp cluster_func, mlir::tf_device::ParallelExecuteOp parallel_execute, mlir::OpBuilder* builder); diff --git a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc index 62b862f5e21..2e1528e0d60 100644 --- a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc @@ -104,26 +104,24 @@ int main(int argc, char** argv) { return 1; } + std::unordered_set tags = absl::StrSplit(saved_model_tags, ','); + std::vector exported_names_vector = + absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty()); + absl::Span exported_names(exported_names_vector); + if (import_saved_model_object_graph) { - std::unordered_set tags = - absl::StrSplit(saved_model_tags, ','); - std::vector exported_names = - absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty()); mlir::MLIRContext context; auto module = tensorflow::SavedModelObjectGraphToMlirImport( - input_filename, tags, absl::Span(exported_names), - &context); + input_filename, tags, exported_names, &context); if (!module) return 1; module->print(output->os()); } else if (import_saved_model_signature_defs) { - std::unordered_set tags = - absl::StrSplit(saved_model_tags, ','); mlir::MLIRContext context; auto module = tensorflow::SavedModelSignatureDefsToMlirImport( - input_filename, tags, &context); + input_filename, tags, exported_names, &context); if (!module) return 1; module->print(output->os()); diff --git a/tensorflow/compiler/mlir/tfjs/BUILD b/tensorflow/compiler/mlir/tfjs/BUILD index 9b731d2c912..ac629ac4573 100644 --- a/tensorflow/compiler/mlir/tfjs/BUILD +++ b/tensorflow/compiler/mlir/tfjs/BUILD @@ -1,4 +1,5 @@ load("//third_party/mlir:tblgen.bzl", "gentbl") +load("//tensorflow:tensorflow.bzl", "tf_cc_binary") package( default_visibility = ["//visibility:public"], @@ -39,7 +40,7 @@ gentbl( "ir/tfjs_ops.td", "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td", - "@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td", + "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", ], ) @@ -131,10 +132,106 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass", - "//tensorflow/compiler/mlir/tensorflow:translate_lib", - "@llvm-project//mlir:Analysis", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Transforms", ], ) + +cc_library( + name = "json_translate_lib", + srcs = [ + "translate/json_translate.cc", + ], + hdrs = [ + "translate/json_translate.h", + ], + deps = [ + ":tensorflow_js", + ":tensorflow_js_dialect_registration", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", + "//tensorflow/compiler/mlir/tensorflow:export_utils", + "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@llvm-project//mlir:IR", + 
"@llvm-project//mlir:Support", + "@llvm-project//mlir:Translation", + ], + alwayslink = 1, +) + +cc_library( + name = "tf_to_tfjs_json", + srcs = ["translate/tf_to_tfjs_json.cc"], + hdrs = [ + "translate/tf_to_tfjs_json.h", + ], + deps = [ + ":json_translate_lib", + ":tfjs_optimize", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:decode_constant_pass", + "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib", + "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", + "//tensorflow/compiler/mlir/tensorflow:translate_cl_options", + "//tensorflow/compiler/mlir/tensorflow:translate_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:support", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], + alwayslink = 1, +) + +tf_cc_binary( + name = "json_translate", + deps = [ + ":json_translate_lib", + "@llvm-project//mlir:MlirTranslateMain", + ], +) + +filegroup( + name = "tf_tfjs_translate_main", + srcs = [ + "translate/tf_tfjs_translate.cc", + ], +) + +tf_cc_binary( + name = "tf_tfjs_translate", + srcs = [":tf_tfjs_translate_main"], + deps = [ + ":json_translate_lib", + ":tensorflow_js_passes", + ":tf_to_tfjs_json", + ":tfjs_optimize", + "//tensorflow/compiler/mlir:init_mlir", + "//tensorflow/compiler/mlir/tensorflow:translate_cl_options", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:errors", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], +) diff --git a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h index 318895de79c..545183a052b 100644 --- a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h +++ b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Interfaces/SideEffects.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project + namespace mlir { namespace tfjs { diff --git a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.td b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.td index 172347bc0f5..134aa010d8c 100644 --- a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.td +++ b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.td @@ -23,7 +23,7 @@ limitations under the License. #define TFJS_DIALECT include "mlir/IR/OpBase.td" -include "mlir/Interfaces/SideEffects.td" +include "mlir/Interfaces/SideEffectInterfaces.td" //===----------------------------------------------------------------------===// // TensorFlow.js dialect definitions diff --git a/tensorflow/compiler/mlir/tfjs/tests/e2e/BUILD b/tensorflow/compiler/mlir/tfjs/tests/e2e/BUILD new file mode 100644 index 00000000000..5c8d37da2f0 --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/tests/e2e/BUILD @@ -0,0 +1,23 @@ +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") + +licenses(["notice"]) + +glob_lit_tests( + data = [ + ":test_utilities", + ], + driver = "@llvm-project//mlir:run_lit.sh", + test_file_exts = [ + "pbtxt", + ], +) + +# Bundle together all of the test utilities that are used by tests. 
+filegroup( + name = "test_utilities", + testonly = True, + data = [ + "//tensorflow/compiler/mlir/tfjs:tf_tfjs_translate", + "@llvm-project//llvm:FileCheck", + ], +) diff --git a/tensorflow/compiler/mlir/tfjs/tests/e2e/add.pbtxt b/tensorflow/compiler/mlir/tfjs/tests/e2e/add.pbtxt new file mode 100644 index 00000000000..f6a324fdc13 --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/tests/e2e/add.pbtxt @@ -0,0 +1,78 @@ +# RUN: tf_tfjs_translate %s -tf-input-arrays=input0,input1 -tf-input-data-types=DT_INT32,DT_INT32 -tf-input-shapes=10:10 -tf-output-arrays=Mul -o - | FileCheck %s --dump-input-on-failure +# Add two tensor<4xi32> inputs and return the result + +node { + name: "Add" + op: "Add" + input: "input0" + input: "input1" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "input0" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } +} +node { + name: "input1" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } +} +node { + name: "Mul" + op: "Mul" + input: "Add" + input: "Add" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +versions { + producer: 27 +} + +# CHECK: "name": "input0" +# CHECK-NEXT: "op": "Placeholder" +# CHECK: "type": "DT_INT32" +# CHECK: "name": "input1", +# CHECK-NEXT: "op": "Placeholder" +# CHECK: "type": "DT_INT32" +# CHECK: "name": "Add" +# CHECK-NEXT: "op": "AddV2" +# CHECK-NEXT: "input": +# CHECK-NEXT: "input0" +# CHECK-NEXT: "input1" +# CHECK: "type": "DT_INT32" +# CHECK: "name": "Mul1" +# CHECK-NEXT: "op": "Mul" +# CHECK-NEXT: "input": +# CHECK-NEXT: "Add" +# CHECK-NEXT: "Add" +# CHECK: "type": "DT_INT32" +# CHECK: "name": "Mul" +# CHECK-NEXT: "op": "_Retval" +# CHECK-NEXT: "input": +# CHECK-NEXT: "Mul1" +# CHECK: "type": "DT_INT32" +# CHECK: "library" +# CHECK: "versions" +# CHECK: "producer": 27 + diff --git a/tensorflow/compiler/mlir/tfjs/tests/e2e/prelu.pbtxt b/tensorflow/compiler/mlir/tfjs/tests/e2e/prelu.pbtxt new file mode 100644 index 00000000000..810db71f5e0 --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/tests/e2e/prelu.pbtxt @@ -0,0 +1,175 @@ +# RUN: tf_tfjs_translate %s -tf-input-arrays=input0 -tf-input-data-types=DT_FLOAT -tf-input-shapes=10 -tf-output-arrays=Add -tf-custom-opdefs="name: 'Prelu' input_arg: { name: 'x' type: DT_FLOAT } input_arg: { name: 'alpha' type: DT_FLOAT } output_arg: { name: 'c' type: DT_FLOAT }" -o - | FileCheck %s --dump-input-on-failure +# Add two tensor<4xi32> inputs and return the result + +node { + name: "input0" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 10 + } + } + } + } + experimental_debug_info { + } +} +node { + name: "alpha" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.5 + } + } + } + experimental_debug_info { + } +} +node { + name: "Relu" + op: "Relu" + input: "input0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + } +} +node { + name: "Neg" + op: "Neg" + input: "input0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + } +} +node { + name: "Relu1" + op: "Relu" + input: "Neg" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + } +} +node { + name: "Mul" + op: "Mul" + input: "alpha" + input: "Relu1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + } +} +node { + name: 
"Add" + op: "Add" + input: "Relu" + input: "Mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + } +} +node { + name: "main" + op: "_Retval" + input: "Add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "index" + value { + i: 0 + } + } +} +library { +} +versions { + producer: 344 +} + +# CHECK: "node": +# CHECK: "name": "input0", +# CHECK-NEXT: "op": "Placeholder", +# CHECK-NEXT: "attr": +# CHECK: "type": "DT_FLOAT" +# CHECK: "name": "Add.Relu.Neg.Relu1.Mul", +# CHECK-NEXT: "op": "Const", +# CHECK-NEXT: "attr": +# CHECK: "value": +# CHECK: "tensor": +# CHECK: "dtype": "DT_FLOAT", +# CHECK: "tensorShape": {}, +# CHECK: "floatVal": +# CHECK: -0.5 +# CHECK: "name": "Add.Relu.Neg.Relu1.Mul1", +# CHECK-NEXT: "op": "Prelu", +# CHECK-NEXT: "input": +# CHECK: "input0", +# CHECK: "Add.Relu.Neg.Relu1.Mul" +# CHECK: "attr": +# CHECK: "_output_shapes": +# CHECK: "list": +# CHECK: "shape": +# CHECK: "dim": +# CHECK: "size": "10" +# CHECK: "experimentalDebugInfo": {} +# CHECK: "name": "Add", +# CHECK-NEXT: "op": "_Retval", +# CHECK-NEXT: "input": +# CHECK: "Add.Relu.Neg.Relu1.Mul1" +# CHECK: "attr": +# CHECK: "T": +# CHECK: "type": "DT_FLOAT" +# CHECK: "library": {}, +# CHECK: "versions": +# CHECK: "producer": 344 + diff --git a/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc b/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc index 631bb1ae2af..a445937570e 100644 --- a/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc +++ b/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,7 +20,6 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tfjs/transforms/passes.h" @@ -47,6 +46,11 @@ void AddTFToTFJSConversionPasses(mlir::OpPassManager* pm) { // Canonicalize, CSE etc. pm->addNestedPass(mlir::createCanonicalizerPass()); pm->addNestedPass(mlir::createCSEPass()); + + // raise to executor dialect in order to use GraphDef converter + pm->addNestedPass( + mlir::CreateFunctionalToExecutorDialectConversionPass()); + pm->addNestedPass(mlir::CreateBreakUpIslandsPass()); } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfjs/translate/json_translate.cc b/tensorflow/compiler/mlir/tfjs/translate/json_translate.cc new file mode 100644 index 00000000000..7f4b8ffae09 --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/translate/json_translate.cc @@ -0,0 +1,105 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/tfjs/translate/json_translate.h" + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Translation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" + +using mlir::ModuleOp; +using mlir::TranslateFromMLIRRegistration; +using std::string; +using tensorflow::Status; +using xla::StatusOr; + +// Translates the given MLIR module in the TFJS dialect to TFJS JSON +// format. Returns false on success. +// +bool tfjs::MlirToJSONTranslateFunction(ModuleOp module, + std::string* serialized_json) { + string json_output; + // Allow TF to treat TFJS ops as TF ops. + if (!tensorflow::AddTensorFlowOpPrefix("tfjs.").ok()) { + LOG(ERROR) << "Failed to add tfjs op prefix."; + return false; + } + tensorflow::GraphExportConfig confs; + confs.export_shapes = true; + confs.export_library = true; + tensorflow::FunctionLibraryDefinition flib_def( + tensorflow::OpRegistry::Global(), tensorflow::FunctionDefLibrary()); + absl::flat_hash_set control_ret_nodes; + auto graph = absl::make_unique(flib_def); + auto status = tensorflow::ConvertMlirToGraph(module, confs, &graph, &flib_def, + &control_ret_nodes); + if (!status.ok()) { + LOG(ERROR) << "Graph export failed: " << status; + return false; + } + auto graphdef = absl::make_unique(); + graph->ToGraphDef(graphdef.get()); + + // Replace the _Arg nodes of the main function with Placeholder op. 
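+  // For illustration: an exported argument surfaces as a NodeDef such as
+  //   node { name: "input0" op: "_Arg" attr { key: "T" ... } attr { key: "index" ... } }
+  // and only its op field is rewritten to "Placeholder" below; the node name
+  // and attributes are left untouched so downstream TFJS tooling sees a
+  // regular graph input.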
+ auto nodes = graphdef->mutable_node(); + for (const auto& node : llvm::enumerate(*nodes)) { + if (node.value().op() == "_Arg") { + nodes->Mutable(node.index())->set_op("Placeholder"); + } + } + + tensorflow::protobuf::util::JsonPrintOptions json_options; + json_options.add_whitespace = true; + auto jsonStatus = tensorflow::protobuf::util::MessageToJsonString( + *graphdef, &json_output, json_options); + if (!jsonStatus.ok()) { + LOG(ERROR) << "Proto2Json failed: " << status; + return false; + } + *serialized_json = std::move(json_output); + return true; +} + +static mlir::LogicalResult MlirToJSONFileTranslateFunction( + ModuleOp module, llvm::raw_ostream& output) { + std::string serialized_json; + if (!tfjs::MlirToJSONTranslateFunction(module, &serialized_json)) + return mlir::failure(); + + output << serialized_json; + return mlir::success(); +} + +static TranslateFromMLIRRegistration MLIRToJSONFileTranslate( + "mlir-to-tfjs-json", MlirToJSONFileTranslateFunction); diff --git a/tensorflow/compiler/mlir/tfjs/translate/json_translate.h b/tensorflow/compiler/mlir/tfjs/translate/json_translate.h new file mode 100644 index 00000000000..0a931f770ad --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/translate/json_translate.h @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_JSON_TRANSLATE_H_ +#define TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_JSON_TRANSLATE_H_ + +#include + +#include "mlir/IR/Module.h" // from @llvm-project +#include "tensorflow/core/lib/core/status.h" + +namespace tfjs { + +// Translates the given MLIR `module` into a JSON string. Returns true if +// translation fails, otherwise returns false. +bool MlirToJSONTranslateFunction(mlir::ModuleOp module, + std::string* serialized_json); +} // namespace tfjs + +#endif // TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_JSON_TRANSLATE_H_ diff --git a/tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc b/tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc new file mode 100644 index 00000000000..e735a3c7b8c --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc @@ -0,0 +1,173 @@ + +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "absl/strings/str_split.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/ToolOutputFile.h" +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "tensorflow/compiler/mlir/init_mlir.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h" +#include "tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.h" +#include "tensorflow/compiler/mlir/tfjs/transforms/passes.h" +#include "tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +using llvm::cl::opt; +using mlir::MLIRContext; +using stream_executor::port::StatusOr; + +// NOLINTNEXTLINE +opt input_file_name(llvm::cl::Positional, + llvm::cl::desc(""), + llvm::cl::init("-")); + +// NOLINTNEXTLINE +opt import_saved_model_object_graph( + "savedmodel-objectgraph-to-mlir", + llvm::cl::desc("Import a saved model to its MLIR representation"), + llvm::cl::value_desc("dir")); + +// NOLINTNEXTLINE +opt import_saved_model_signature_defs( + "savedmodel-signaturedefs-to-mlir", + llvm::cl::desc("Import a saved model V1 to its MLIR representation"), + llvm::cl::value_desc("dir")); + +// NOLINTNEXTLINE +opt saved_model_tags( + "tf-savedmodel-tags", + llvm::cl::desc("Tags used to indicate which MetaGraphDef to import, " + "separated by ','"), + llvm::cl::init("serve")); + +// NOLINTNEXTLINE +opt saved_model_exported_names( + "tf-savedmodel-exported-names", + llvm::cl::desc("Names to export from SavedModel, separated by ','. Empty " + "(the default) means export all."), + llvm::cl::init("")); + +// NOLINTNEXTLINE +opt output_file_name("o", llvm::cl::desc(""), + llvm::cl::value_desc("filename"), + llvm::cl::init("-")); +// NOLINTNEXTLINE +opt input_mlir( + "input-mlir", + llvm::cl::desc("Take input TensorFlow model in textual MLIR instead of " + "GraphDef format"), + llvm::cl::init(false), llvm::cl::Hidden); +// NOLINTNEXTLINE +opt output_mlir( + "output-mlir", + llvm::cl::desc("Output MLIR rather than JSON for the generated TFJS model"), + llvm::cl::init(false)); + +// The following approach allows injecting opdefs in addition +// to those that are already part of the global TF registry to be linked in +// prior to importing the graph. The primary goal is for support of custom ops. +// This is not intended to be a general solution for custom ops for the future +// but mainly for supporting older models like mobilenet_ssd. More appropriate +// mechanisms, such as op hints or using functions to represent composable ops +// like https://github.com/tensorflow/community/pull/113 should be encouraged +// going forward. +// NOLINTNEXTLINE +llvm::cl::list custom_opdefs( + "tf-custom-opdefs", llvm::cl::desc("List of custom opdefs when importing " + "graphdef")); + +// Debugging flag to print function mapping in the JSON. 
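+// Example invocation converting a GraphDef to TFJS JSON (mirrors the RUN lines
+// under tensorflow/compiler/mlir/tfjs/tests/e2e/; file names are illustrative):
+//
+//   tf_tfjs_translate add.pbtxt -tf-input-arrays=input0,input1 \
+//     -tf-input-data-types=DT_INT32,DT_INT32 -tf-input-shapes=10:10 \
+//     -tf-output-arrays=Mul -o add.json
+//
+// Adding -print-function-result-mapping (defined below) also prints the
+// generated output to stdout.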
+// NOLINTNEXTLINE +static opt print_function_result_mapping( + "print-function-result-mapping", + llvm::cl::desc( + "Print the mapping of function result to json output buffer"), + llvm::cl::init(false)); + +enum TranslationStatus { kTrSuccess, kTrFailure }; + +static int PrintFunctionResultMapping(const std::string& result) { + std::cout << result << std::endl; + return kTrSuccess; +} + +int main(int argc, char** argv) { + tensorflow::InitMlir y(&argc, &argv); + + llvm::cl::ParseCommandLineOptions(argc, argv, + "TF GraphDef to TFJS JSON converter\n"); + + MLIRContext context; + llvm::SourceMgr source_mgr; + mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context); + + StatusOr module; + + if (import_saved_model_object_graph || import_saved_model_signature_defs) { + if (input_mlir) + module = tensorflow::errors::InvalidArgument( + "Importing saved model should not have input_mlir set"); + module = tensorflow::ImportSavedModel( + import_saved_model_object_graph, import_saved_model_signature_defs, + custom_opdefs, input_file_name, saved_model_tags, + saved_model_exported_names, &context); + } else { + module = tensorflow::LoadFromGraphdefOrMlirSource( + input_file_name, input_mlir, custom_opdefs, debug_info_file, + input_arrays, input_dtypes, input_shapes, output_arrays, + /*prune_unused_nodes=*/true, &source_mgr, &context); + } + + // If errors occur, the library call in the above already logged the error + // message. So we can just return here. + if (!module.ok()) return kTrFailure; + + mlir::PassManager pm(&context); + + tensorflow::AddTFToTFJSConversionPasses(&pm); + + std::string result; + auto status = tensorflow::ConvertTFOpsToTfjsJSON(module.ValueOrDie().get(), + output_mlir, &result, &pm); + if (!status.ok()) return kTrFailure; + + std::string error_msg; + auto output = mlir::openOutputFile(output_file_name, &error_msg); + if (output == nullptr) { + llvm::errs() << error_msg << '\n'; + return kTrFailure; + } + output->os() << result; + output->keep(); + + // Print out debugging info related to function mapping. + if (print_function_result_mapping) return PrintFunctionResultMapping(result); + return kTrSuccess; +} diff --git a/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.cc b/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.cc new file mode 100644 index 00000000000..7dc9ea049ba --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.cc @@ -0,0 +1,152 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h" + +#include +#include +#include +#include +#include + +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Parser.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/compiler/mlir/tfjs/translate/json_translate.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { + +using mlir::MLIRContext; +using mlir::ModuleOp; +using mlir::OwningModuleRef; +using stream_executor::port::StatusOr; + +namespace { +tensorflow::Status RegisterCustomOps( + const std::vector& extra_tf_opdefs) { + for (const auto& tf_opdefs_string : extra_tf_opdefs) { + tensorflow::OpDef opdef; + if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string, + &opdef)) { + LOG(ERROR) << "OpDef parsing failed for: " << tf_opdefs_string; + return errors::InvalidArgument("fail to parse extra OpDef"); + } + // Register extra opdefs. + tensorflow::OpRegistry::Global()->Register( + [opdef](tensorflow::OpRegistrationData* op_reg_data) -> Status { + *op_reg_data = tensorflow::OpRegistrationData(opdef); + return Status::OK(); + }); + } + return Status::OK(); +} +} // namespace + +StatusOr LoadFromGraphdefOrMlirSource( + const std::string& input_filename, bool input_mlir, + const std::vector& extra_tf_opdefs, + absl::string_view debug_info_file, absl::string_view input_arrays, + absl::string_view input_dtypes, absl::string_view input_shapes, + absl::string_view output_arrays, bool prune_unused_nodes, + llvm::SourceMgr* source_mgr, MLIRContext* context) { + // Set up the input file. 
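+  // Two input flavors are handled below: a textual MLIR module, which is
+  // handed to the source manager and parsed directly, and a GraphDef, for
+  // which any extra custom OpDefs are registered before running the standard
+  // GraphDef-to-MLIR importer.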
+ std::string error_message; + auto file = mlir::openInputFile(input_filename, &error_message); + if (!file) { + llvm::errs() << error_message << "\n"; + return errors::InvalidArgument("fail to open input file"); + } + + if (input_mlir) { + source_mgr->AddNewSourceBuffer(std::move(file), llvm::SMLoc()); + return OwningModuleRef(mlir::parseSourceFile(*source_mgr, context)); + } + + TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs)); + + return tensorflow::GraphdefToMlirTranslateFunction( + file->getBuffer(), debug_info_file, input_arrays, input_dtypes, + input_shapes, output_arrays, /*control_output_arrays=*/"", + prune_unused_nodes, /*convert_legacy_fed_inputs=*/true, + /*graph_as_function=*/false, /*upgrade_legacy=*/true, + /*enable_shape_inference=*/true, context); +} + +Status ConvertTFOpsToTfjsJSON(mlir::ModuleOp module, bool export_to_mlir, + std::string* result, + mlir::PassManager* pass_manager) { + mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(), + /*propagate=*/true); + if (failed(pass_manager->run(module))) { + return statusHandler.ConsumeStatus(); + } + + if (export_to_mlir) { + llvm::raw_string_ostream os(*result); + module.print(os); + return Status::OK(); + } + + return tfjs::MlirToJSONTranslateFunction(module, result) + ? Status::OK() + : statusHandler.ConsumeStatus(); +} + +StatusOr ImportSavedModel( + bool import_saved_model, bool import_saved_model_v1, + const std::vector& extra_tf_opdefs, + const std::string& input_filename, const std::string& saved_model_tags, + const std::string& saved_model_exported_names, mlir::MLIRContext* context) { + std::unordered_set tags = absl::StrSplit(saved_model_tags, ','); + std::vector exported_names_in_vector = + absl::StrSplit(saved_model_exported_names, ',', absl::SkipEmpty()); + absl::Span exported_names(exported_names_in_vector); + if (import_saved_model) { + auto module = tensorflow::SavedModelObjectGraphToMlirImport( + input_filename, tags, absl::Span(exported_names), context); + if (!module) + return tensorflow::errors::InvalidArgument("fail to open input file"); + TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs)); + return module; + } else if (import_saved_model_v1) { + auto module = tensorflow::SavedModelSignatureDefsToMlirImport( + input_filename, tags, exported_names, context); + + if (!module) + return tensorflow::errors::InvalidArgument("fail to open input file"); + TF_RETURN_IF_ERROR(RegisterCustomOps(extra_tf_opdefs)); + return module; + } else { + return tensorflow::errors::InvalidArgument( + "Should be either saved model v1 or v2"); + } +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h b/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h new file mode 100644 index 00000000000..d68f0e7d46e --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/translate/tf_to_tfjs_json.h @@ -0,0 +1,63 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_TF_TO_TFJS_JSON_H_ +#define TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_TF_TO_TFJS_JSON_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "llvm/Support/SourceMgr.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/core/platform/status.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { + +// Load a TF model from a GraphDef definition or a TF control flow dialect MLIR +// source into a MLIR module. If `input_mlir` is true, load from a MLIR source +// file; otherwise, load from a GraphDef. +// Setting prune_unused_nodes to true, would prune unreachable nodes if +// output_arrays is specified. +stream_executor::port::StatusOr +LoadFromGraphdefOrMlirSource( + const std::string& input_filename, bool input_mlir, + const std::vector& extra_tf_opdefs, + absl::string_view debug_info_file, absl::string_view input_arrays, + absl::string_view input_dtypes, absl::string_view input_shapes, + absl::string_view output_arrays, bool prune_unused_nodes, + llvm::SourceMgr* source_mgr, mlir::MLIRContext* context); + +// Load Saved model (either v1 or v2) into MLIR. +stream_executor::port::StatusOr ImportSavedModel( + bool import_saved_model, bool import_saved_model_v1, + const std::vector& extra_tf_opdefs, + const std::string& input_filename, const std::string& saved_model_tags, + const std::string& saved_model_exported_names, mlir::MLIRContext* context); + +// Taking a MLIR module in TF executor dialect and a set of parameters, +// applies a set of passes to convert the module to TFJS dialect and +// serializes the result to JSON string. +// If `export_to_mlir` is true, the result is exported in MLIR text format, +// otherwise exported in JSON. 
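+// Illustrative usage (a sketch mirroring tf_tfjs_translate.cc, not a required
+// calling convention):
+//
+//   mlir::PassManager pm(&context);
+//   tensorflow::AddTFToTFJSConversionPasses(&pm);
+//   std::string json;
+//   TF_RETURN_IF_ERROR(ConvertTFOpsToTfjsJSON(module, /*export_to_mlir=*/false,
+//                                             &json, &pm));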
+Status ConvertTFOpsToTfjsJSON(mlir::ModuleOp module, bool export_to_mlir, + std::string* result, + mlir::PassManager* pass_manager); +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFJS_TRANSLATE_TF_TO_TFJS_JSON_H_ diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD new file mode 100644 index 00000000000..27a8dbd2809 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD @@ -0,0 +1,50 @@ +load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") + +licenses(["notice"]) + +cc_library( + name = "cubin_creator", + srcs = ["cubin_creator.cc"], + hdrs = ["cubin_creator.h"], + copts = if_cuda(["-DGOOGLE_CUDA=1"]), + deps = [ + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:support", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:TargetNVVMIR", + "@llvm-project//mlir:Transforms", + "//tensorflow/compiler/mlir/xla:hlo", + "//tensorflow/compiler/mlir/xla:lhlo", + "//tensorflow/compiler/mlir/xla:xla_legalize_tf", + "//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts", # buildcleaner: keep + "//tensorflow/compiler/mlir/xla:xla_unfuse_batch_norm", # buildcleaner: keep + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service/gpu:stream_executor_util", + "//tensorflow/compiler/xla/service/gpu:target_constants", + "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend", + "//tensorflow/compiler/xla/service/mlir_gpu:kernel_lowering", + "//tensorflow/core:cuda_libdevice_path", + "//tensorflow/core:lib", + ] + if_cuda(["//tensorflow/stream_executor/gpu:asm_compiler"]), +) + +tf_cc_binary( + name = "tf_to_cubin", + srcs = ["tf_to_cubin.cc"], + visibility = ["//tensorflow/core/kernels/cubin_headers:__pkg__"], + deps = [ + ":cubin_creator", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + ], +) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc new file mode 100644 index 00000000000..b1c4b1beae1 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc @@ -0,0 +1,264 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +//===- cubin_creator.cc -----------------------------------------*- C++ -*-===// +// +// This file implements the function to compile a TF kernel function to a cubin. 
+// +//===----------------------------------------------------------------------===// +#include "tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h" + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/escaping.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Parser.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Target/NVVMIR.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" +#include "tensorflow/compiler/mlir/xla/transforms/passes.h" +#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" +#include "tensorflow/compiler/xla/service/gpu/target_constants.h" +#include "tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h" +#include "tensorflow/core/platform/cuda_libdevice_path.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/path.h" +#if GOOGLE_CUDA +#include "tensorflow/stream_executor/gpu/asm_compiler.h" +#endif + +namespace { +using tensorflow::Status; +using xla::InternalError; +using xla::StatusOr; + +StatusOr GetLibdeviceDir( + const xla::HloModuleConfig& hlo_module_config) { + for (const std::string& cuda_root : tensorflow::CandidateCudaRoots( + hlo_module_config.debug_options().xla_gpu_cuda_data_dir())) { + std::string libdevice_dir = + tensorflow::io::JoinPath(cuda_root, "nvvm", "libdevice"); + VLOG(2) << "Looking for libdevice at " << libdevice_dir; + if (tensorflow::Env::Default()->IsDirectory(libdevice_dir).ok()) { + VLOG(2) << "Found libdevice dir " << libdevice_dir; + return libdevice_dir; + } + } + return InternalError( + "Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice"); +} + +struct MaterializeBroadcastsPass + : public mlir::PassWrapper { + void runOnFunction() override { + mlir::ConversionTarget conversionTarget(getContext()); + mlir::OwningRewritePatternList conversionPatterns; + + // Consider the xla_hlo dialect legal for tests. + conversionTarget.addLegalDialect(); + // The conversion uses helpers from the Standard dialect. 
+ conversionTarget.addLegalDialect(); + + mlir::xla_hlo::SetupMaterializeBroadcastsLegality(&getContext(), + &conversionTarget); + mlir::xla_hlo::PopulateMaterializeBroadcastsPatterns(&getContext(), + &conversionPatterns); + + if (failed(applyPartialConversion(getFunction(), conversionTarget, + conversionPatterns))) { + return signalPassFailure(); + } + } +}; + +struct UnfuseBatchNormPass + : public mlir::PassWrapper { + void runOnFunction() override { + mlir::OwningRewritePatternList patterns; + mlir::xla_hlo::PopulateUnfuseBatchNormPatterns(&getContext(), &patterns); + mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); + } +}; + +Status LowerTfOpToLhloWithDynamicShapes(mlir::ModuleOp module) { + mlir::PassManager pm(module.getContext()); + auto enable_if_vlog_is_on = [](mlir::Pass* pass, mlir::Operation* op) { + return VLOG_IS_ON(1); + }; + pm.enableIRPrinting(/*shouldPrintBeforePass=*/{}, + /*shouldPrintAfterPass=*/enable_if_vlog_is_on, + /*printModuleScope=*/false, + /*printAfterOnlyOnChange=*/false, llvm::dbgs()); + pm.addNestedPass(mlir::xla_hlo::createLegalizeTFPass(false)); + pm.addNestedPass( + absl::make_unique()); + pm.addNestedPass(absl::make_unique()); + pm.addPass(mlir::xla_hlo::createLegalizeToLhloPass()); + pm.addNestedPass(mlir::xla_lhlo::createLhloCopyRemovalPass()); + + if (failed(pm.run(module))) { + return InternalError("Lowering TF to LHLO failed."); + } + return Status::OK(); +} + +struct PropagateStaticKnowledge + : public mlir::PassWrapper> { + explicit PropagateStaticKnowledge(mlir::FunctionType type, + llvm::ArrayRef same_shape_) + : func_type(type), same_shape(same_shape_) {} + + void runOnOperation() override { + // We know due to tensorflow ABI that the offset is always 0 and that the + // innermost stride is always 1. To make this visible to the compiler, + // we insert constants into the code and replace usages accordingly. + // We do not change the signature so that we keep a somewhat stable ABI + // that is easy to undertand by tools. + mlir::LLVM::LLVMFuncOp func = getOperation(); + mlir::OpBuilder b(func.getBody()); + auto index_type = func.getArgument(3).getType(); + mlir::Value one = b.create( + func.getLoc(), index_type, b.getIntegerAttr(b.getIndexType(), 1)); + mlir::Value zero = b.create( + func.getLoc(), index_type, b.getIntegerAttr(b.getIndexType(), 0)); + uint32_t arg_pos = 0; + std::vector positions; + for (mlir::Type arg_type : func_type.getInputs()) { + positions.push_back(arg_pos); + func.getArgument(arg_pos + 2).replaceAllUsesWith(zero); + arg_pos += 3 + arg_type.cast().getRank() * 2; + func.getArgument(arg_pos - 1).replaceAllUsesWith(one); + } + + // If we have knowledge that some arguments have the same shape, we + // can use that here. Simply replace usages of the shape parameters within + // the function body to a single shape parameter. + if (!same_shape.empty()) { + auto first = same_shape.front(); + auto first_offset = positions.at(first); + mlir::ShapedType first_type = + func_type.getInput(first).cast(); + uint32_t rank = first_type.getRank(); + for (auto same : same_shape.drop_front(1)) { + uint32_t same_offset = positions.at(same); + auto same_type = func_type.getInput(same).cast(); + if (same_type.getRank() != rank) { + func.emitOpError() << "same shape constraints on arguments with " + "non-matching shapes: #" + << first << " and #" << same; + signalPassFailure(); + } + + for (uint32_t i = 0; i < 2 * rank; ++i) { + // Replace uses for second arg data with first arg. 
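+        // Illustrative layout note (assumes the standard memref argument
+        // expansion used above, where each shaped argument becomes
+        //   [allocated ptr, aligned ptr, offset, size_0..size_{r-1},
+        //    stride_0..stride_{r-1}]
+        // i.e. 3 + 2*rank positions): the sizes and strides of argument
+        // `same` start at `same_offset + 3`, and each is rewired to the
+        // matching position of argument `first`.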
+ auto same_arg = func.getArgument(same_offset + 3 + i); + auto first_arg = func.getArgument(first_offset + 3 + i); + same_arg.replaceAllUsesWith(first_arg); + } + } + } + } + + mlir::FunctionType func_type; + llvm::ArrayRef same_shape; +}; + +Status PropagateStaticShapeKnowledgeToKernel( + mlir::ModuleOp module, llvm::ArrayRef same_shape) { + // Grab the original signature from the single function. + auto func = *module.getBody()->op_begin(); + + mlir::PassManager pm(module.getContext()); + auto enable_if_vlog_is_on = [](mlir::Pass*, mlir::Operation*) { + return VLOG_IS_ON(1); + }; + pm.enableIRPrinting(/*shouldPrintBeforePass=*/{}, + /*shouldPrintAfterPass=*/enable_if_vlog_is_on, + /*printModuleScope=*/false, + /*printAfterOnlyOnChange=*/false, llvm::dbgs()); + auto& kernel_pm = pm.nest<::mlir::gpu::GPUModuleOp>(); + kernel_pm.addNestedPass( + absl::make_unique(func.getType(), same_shape)); + + if (failed(pm.run(module))) { + return InternalError("Static knowledge propagation failed."); + } + return Status::OK(); +} +} // namespace + +StatusOr> tensorflow::kernel_gen::GenerateCubinForTfCode( + llvm::StringRef tf_code, std::pair compute_capability, + llvm::ArrayRef tile_sizes, llvm::ArrayRef same_shape, + llvm::ArrayRef unroll_factors) { + mlir::MLIRContext context; + context.allowUnregisteredDialects(); // TODO(b/152572127) + mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context); + + TF_RETURN_IF_ERROR(LowerTfOpToLhloWithDynamicShapes(module.get())); + TF_RETURN_IF_ERROR( + xla::mlir_gpu::LowerLHLOToGPU(module.get(), tile_sizes, unroll_factors, + /*collapseParallelLoops=*/false)); + TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get())); + TF_RETURN_IF_ERROR( + PropagateStaticShapeKnowledgeToKernel(module.get(), same_shape)); + + mlir::OwningModuleRef kernel_module = + xla::mlir_gpu::ExtractKernelModule(*module).ValueOrDie(); + auto llvmModule = mlir::translateModuleToNVVMIR(*kernel_module); + if (!llvmModule) { + return InternalError("Could not translate MLIR module to NVVM"); + } + + llvmModule->setModuleIdentifier("acme"); + llvmModule->setDataLayout(xla::gpu::nvptx::kDataLayout); + + xla::HloModuleConfig config; + config.set_debug_options(xla::GetDebugOptionsFromFlags()); + + TF_ASSIGN_OR_RETURN(std::string libdevice_dir, GetLibdeviceDir(config)); + TF_ASSIGN_OR_RETURN(std::string ptx, xla::gpu::nvptx::CompileToPtx( + llvmModule.get(), compute_capability, + config, libdevice_dir)); + VLOG(1) << ptx; + +#if GOOGLE_CUDA + return tensorflow::se::CompileGpuAsm( + std::get<0>(compute_capability), std::get<1>(compute_capability), + ptx.c_str(), xla::gpu::PtxOptsFromConfig(config)); +#else + return InternalError( + "GOOGLE_CUDA not defined. Did you specify --config=cuda ?"); +#endif +} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h new file mode 100644 index 00000000000..47626ba9d0d --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h @@ -0,0 +1,42 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +//===- cubin_creator.h ------------------------------------------*- C++ -*-===// +// +// This file declares the function to compile a TF kernel function to a cubin. +// +//===----------------------------------------------------------------------===// +#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_ +#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_ + +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace tensorflow { +namespace kernel_gen { +xla::StatusOr> GenerateCubinForTfCode( + llvm::StringRef tf_code, + std::pair compute_capability = {7, 5}, + llvm::ArrayRef tile_sizes = {16, 64}, + llvm::ArrayRef same_shape = {}, + llvm::ArrayRef unroll_factors = {}); +} // namespace kernel_gen +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_CUBIN_CREATOR_H_ diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc new file mode 100644 index 00000000000..8edc567e777 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc @@ -0,0 +1,118 @@ +// Copyright 2020 The TensorFlow Runtime Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//===- tf_to_cubin.cc -------------------------------------------*- C++ -*-===// +// +// This file implements the entry point to compile a tf op to a cubin file. +// +//===----------------------------------------------------------------------===// +#include +#include +#include + +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace { +bool ParseStringList(std::string string_list, std::vector* result) { + result->clear(); + uint32_t item; + auto items = absl::StrSplit(string_list, ','); + for (const auto& item_str : items) { + if (!absl::SimpleAtoi(item_str, &item)) { + LOG(ERROR) << "Expected token " << item_str << " to be an integer"; + return false; + } + result->push_back(item); + } + return true; +} +} // namespace + +int main(int argc, char** argv) { + std::string output_file = "foo.bin"; + int32_t architecture = 50; + std::vector tile_sizes; + std::vector unroll_factors; + std::vector same_shape; + + auto parse_tile_sizes = [&tile_sizes](std::string tile_sizes_str) { + if (!ParseStringList(tile_sizes_str, &tile_sizes)) { + return false; + } + // Initialize with the default. 
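+    // The flag value is a comma-separated list of integers, e.g.
+    // --tile_sizes=16,64 (the default advertised in the flag help below).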
+ if (tile_sizes.empty()) { + tile_sizes.push_back(16); + tile_sizes.push_back(64); + } + return true; + }; + + auto parse_unroll_factors = + [&unroll_factors](std::string unroll_factors_str) { + return ParseStringList(unroll_factors_str, &unroll_factors); + }; + + auto parse_same_shape = [&same_shape](std::string same_shape_str) { + return ParseStringList(same_shape_str, &same_shape); + }; + + std::vector flag_list = { + tensorflow::Flag("output", &output_file, "output file"), + tensorflow::Flag("arch", &architecture, + "target architecture (e.g. 50 for sm_50)"), + tensorflow::Flag("tile_sizes", parse_tile_sizes, "16,64", + "tile sizes to use"), + tensorflow::Flag("unroll_factors", parse_unroll_factors, "", + "factors to unroll by, separated by commas"), + tensorflow::Flag("same_shape", parse_same_shape, "", + "arguments with same shape, separated by commas"), + }; + bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); + tensorflow::port::InitMain("usage", &argc, &argv); + if (!parse_ok) { + return 1; + } + + std::pair compute_capability(architecture / 10, + architecture % 10); + + auto cubin = tensorflow::kernel_gen::GenerateCubinForTfCode( + argv[1], compute_capability, tile_sizes, same_shape, unroll_factors); + + if (!cubin.ok()) { + LOG(ERROR) << cubin.status(); + return 1; + } + + std::vector cubin_data = cubin.ConsumeValueOrDie(); + + auto status = tensorflow::WriteStringToFile( + tensorflow::Env::Default(), output_file, + absl::string_view{reinterpret_cast(cubin_data.data()), + cubin_data.size()}); + + if (!status.ok()) { + LOG(ERROR) << status; + return 1; + } + + return 0; +} diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 598383d81ec..12334e463fa 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -23,7 +23,6 @@ package_group( "//tensorflow/compiler/xla/...", "//third_party/iree/...", "//third_party/mlir_edge/...", - "//third_party/tf_runtime/tools/tf_kernel_gen/...", ], ) @@ -39,7 +38,7 @@ filegroup( "ir/lhlo_ops.td", "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", - "@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td", + "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", ], ) @@ -133,6 +132,7 @@ cc_library( "transforms/legalize_tf_control_flow.cc", ], deps = [ + ":chlo_legalize_to_hlo", ":convert_op_folder", ":hlo", "//tensorflow/compiler/mlir/tensorflow", @@ -165,6 +165,7 @@ cc_library( ":mlir_hlo_builder", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:convert_tensor", "//tensorflow/compiler/mlir/tensorflow:convert_type", "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op", "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib", @@ -186,6 +187,7 @@ cc_library( "@llvm-project//llvm:support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", ], alwayslink = 1, @@ -239,8 +241,8 @@ cc_library( "@llvm-project//llvm:support", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgOps", - "@llvm-project//mlir:LoopOps", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Transforms", ], @@ -277,8 +279,8 @@ cc_library( "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgOps", - "@llvm-project//mlir:LoopOps", "@llvm-project//mlir:Pass", + 
"@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Transforms", ], @@ -397,9 +399,8 @@ cc_library( cc_library( name = "xla_hlo_to_lhlo_with_xla", - srcs = [ - "transforms/xla_hlo_to_lhlo_with_xla.cc", - ], + srcs = ["transforms/xla_hlo_to_lhlo_with_xla.cc"], + hdrs = ["transforms/xla_hlo_to_lhlo_with_xla.h"], deps = [ ":hlo", ":hlo_utils", @@ -588,6 +589,7 @@ cc_library( "//tensorflow/compiler/xla:comparison_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:shape_inference", @@ -717,6 +719,7 @@ cc_library( "//tensorflow/compiler/xla/client/lib:slicing", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:framework", + "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/stream_executor/lib", "@llvm-project//llvm:support", @@ -820,7 +823,7 @@ genrule( name = "operator_writer_inc", srcs = [ "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", - "@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td", + "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", "@llvm-project//mlir:include/mlir/IR/OpBase.td", ":ir/hlo_ops.td", ":ir/hlo_ops_base.td", diff --git a/tensorflow/compiler/mlir/xla/attribute_importer.cc b/tensorflow/compiler/mlir/xla/attribute_importer.cc index 2d17127b075..201ec0d053f 100644 --- a/tensorflow/compiler/mlir/xla/attribute_importer.cc +++ b/tensorflow/compiler/mlir/xla/attribute_importer.cc @@ -117,7 +117,7 @@ mlir::xla_hlo::ConvDimensionNumbers ConvertConvDimensionNumbers( builder->getI64IntegerAttr(dnums.kernel_output_feature_dimension()), Convert(kernel_spatial_dims, builder), builder->getI64IntegerAttr(dnums.output_batch_dimension()), - builder->getI64IntegerAttr(dnums.kernel_output_feature_dimension()), + builder->getI64IntegerAttr(dnums.output_feature_dimension()), Convert(output_spatial_dims, builder), builder->getContext()); } diff --git a/tensorflow/compiler/mlir/xla/hlo_utils.cc b/tensorflow/compiler/mlir/xla/hlo_utils.cc index c685cc296fd..dc801f64ede 100644 --- a/tensorflow/compiler/mlir/xla/hlo_utils.cc +++ b/tensorflow/compiler/mlir/xla/hlo_utils.cc @@ -139,6 +139,10 @@ StatusOr CreateDenseElementsAttrFromLiteral( return CreateDenseAttrFromLiteral(type, literal); case PrimitiveType::U64: return CreateDenseAttrFromLiteral(type, literal); + case PrimitiveType::C64: + return CreateDenseAttrFromLiteral(type, literal); + case PrimitiveType::C128: + return CreateDenseAttrFromLiteral(type, literal); default: return tensorflow::errors::Internal( absl::StrCat("Unsupported type: ", PrimitiveType_Name(element_type))); diff --git a/tensorflow/compiler/mlir/xla/ir/chlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/chlo_ops.cc index bc6842a617e..5322668aa2e 100644 --- a/tensorflow/compiler/mlir/xla/ir/chlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/chlo_ops.cc @@ -97,16 +97,12 @@ static Type GetBroadcastType(Type x, Type y, Type element_type, LogicalResult InferBroadcastBinaryOpReturnTypeComponents( MLIRContext* context, Optional location, ValueRange operands, - ArrayRef attributes, Type element_type, + DictionaryAttr attributes, Type element_type, SmallVectorImpl& inferedReturnShapes) { // Find broadcast_dimensions. 
- DenseIntElementsAttr broadcast_dimensions; - for (auto attr : attributes) { - if (attr.first == "broadcast_dimensions") { - broadcast_dimensions = attr.second.dyn_cast(); - break; - } - } + DenseIntElementsAttr broadcast_dimensions = + attributes.get("broadcast_dimensions") + .dyn_cast_or_null(); ShapedType lhs_type = operands[0].getType().dyn_cast(); ShapedType rhs_type = operands[1].getType().dyn_cast(); @@ -168,7 +164,7 @@ LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes( LogicalResult BroadcastComplexOp::inferReturnTypeComponents( MLIRContext* context, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferedReturnShapes) { ShapedType lhs_type = operands[0].getType().dyn_cast(); if (!lhs_type) { @@ -191,7 +187,7 @@ LogicalResult BroadcastComplexOp::reifyReturnTypeShapes( LogicalResult BroadcastCompareOp::inferReturnTypeComponents( MLIRContext* context, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferedReturnShapes) { Type element_type = IntegerType::get(1, context); return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands, @@ -211,7 +207,7 @@ LogicalResult BroadcastCompareOp::reifyReturnTypeShapes( #define BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(Op) \ LogicalResult Op::inferReturnTypeComponents( \ MLIRContext* context, Optional location, ValueRange operands, \ - ArrayRef attributes, RegionRange regions, \ + DictionaryAttr attributes, RegionRange regions, \ SmallVectorImpl& inferedReturnShapes) { \ return InferBroadcastBinaryOpReturnTypeComponents( \ context, location, operands, attributes, /*element_type=*/nullptr, \ diff --git a/tensorflow/compiler/mlir/xla/ir/chlo_ops.td b/tensorflow/compiler/mlir/xla/ir/chlo_ops.td index a244985c9b5..f9672c1a95a 100644 --- a/tensorflow/compiler/mlir/xla/ir/chlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/chlo_ops.td @@ -31,7 +31,7 @@ limitations under the License. include "mlir/IR/OpBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" -include "mlir/Interfaces/SideEffects.td" +include "mlir/Interfaces/SideEffectInterfaces.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td" def HLOClient_Dialect : Dialect { diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc index c9742ad5337..68eafb8b33e 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include #include +#include #include "absl/container/flat_hash_set.h" #include "llvm/ADT/APFloat.h" @@ -842,6 +843,51 @@ void ConcatenateOp::getCanonicalizationPatterns( results.insert(context); } +template +static Attribute foldConcatenateHelper(ConcatenateOp* op, + ArrayRef operands) { + auto axis = op->dimension().getLimitedValue(); + auto type = op->getType().cast(); + + SmallVector values; + auto shape = type.getShape(); + + size_t top_size = 1; + for (int i = 0; i < axis; i++) { + top_size = top_size * shape[i]; + } + + for (size_t i = 0; i < top_size; i++) { + for (auto operand : operands) { + DenseElementsAttr attr = operand.cast(); + size_t bottom_size = attr.getNumElements() / top_size; + auto iter = attr.getValues().begin() + i * bottom_size; + values.append(iter, iter + bottom_size); + } + } + + return DenseElementsAttr::get(type, values); +} + +static Attribute foldConcatenate(ConcatenateOp* op, + ArrayRef operands) { + for (auto operand : operands) { + if (!operand) return {}; + } + + auto type = op->getResult().getType().cast(); + auto etype = type.getElementType(); + if (etype.isa()) { + return foldConcatenateHelper(op, operands); + } + + if (etype.isa()) { + return foldConcatenateHelper(op, operands); + } + + return {}; +} + OpFoldResult ConcatenateOp::fold(ArrayRef operands) { if (getNumOperands() == 1) return getOperand(0); @@ -849,6 +895,10 @@ OpFoldResult ConcatenateOp::fold(ArrayRef operands) { if (!type.hasStaticShape()) return {}; auto axis = dimension().getLimitedValue(); + if (auto attr = foldConcatenate(this, operands)) { + return attr; + } + llvm::SmallVector new_operands; for (auto operand : getOperands()) { auto ty = operand.getType().cast(); @@ -1120,9 +1170,22 @@ OpFoldResult CopyOp::fold(ArrayRef operands) { return getOperand(); } //===----------------------------------------------------------------------===// OpFoldResult ReverseOp::fold(ArrayRef operands) { + auto input = operand(); + // No dimensions to reverse. - if (dimensions().getNumElements() == 0) return operand(); - return nullptr; + if (dimensions().getNumElements() == 0) return input; + + llvm::SmallVector new_dims; + new_dims.reserve(dimensions().getNumElements()); + + auto shaped_type = input.getType().cast(); + for (auto dim : dimensions().getValues()) { + if (shaped_type.getDimSize(dim.getLimitedValue()) != 1) { + return nullptr; + } + } + + return input; } //===----------------------------------------------------------------------===// @@ -1190,7 +1253,7 @@ static LogicalResult Verify(SelectOp op) { // the return type based on operand type. LogicalResult SelectOp::inferReturnTypes( MLIRContext*, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { auto x_type = operands[1].getType(); auto y_type = operands[2].getType(); @@ -1412,6 +1475,53 @@ BINARY_BUILDER(XorOp); #undef BINARY_BUILDER +template +static Attribute BinaryFolder(Op* op, ArrayRef attrs) { + if (!attrs[0] || !attrs[1]) return {}; + if (op->broadcast_dimensions().hasValue()) return {}; + + DenseElementsAttr lhs = attrs[0].dyn_cast(); + DenseElementsAttr rhs = attrs[1].dyn_cast(); + if (!lhs || !rhs) return {}; + + ShapedType type = op->getType().template cast(); + if (!type.hasStaticShape()) { + return {}; + } + + Type etype = type.getElementType(); + + // Evaluate for integer values. 
+ if (!etype.isa()) { + return {}; + } + + SmallVector values; + values.reserve(lhs.getNumElements()); + for (const auto zip : + llvm::zip(lhs.getValues(), rhs.getValues())) { + values.push_back(Convert()(std::get<0>(zip), std::get<1>(zip))); + } + + return DenseElementsAttr::get(type, values); +} + +#define BINARY_FOLDER(Op, Func) \ + OpFoldResult Op::fold(ArrayRef attrs) { \ + if (getElementTypeOrSelf(getType()).isa()) \ + return BinaryFolder>(this, attrs); \ + if (getElementTypeOrSelf(getType()).isa()) \ + return BinaryFolder>(this, attrs); \ + return {}; \ + } + +BINARY_FOLDER(AddOp, std::plus); +BINARY_FOLDER(SubOp, std::minus); +BINARY_FOLDER(MulOp, std::multiplies); + +#undef BINARY_FOLDER + //===----------------------------------------------------------------------===// // SliceOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td index 16c9a7b4f05..f78ac7624d2 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td @@ -23,7 +23,7 @@ limitations under the License. include "mlir/IR/OpBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" -include "mlir/Interfaces/SideEffects.td" +include "mlir/Interfaces/SideEffectInterfaces.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td" include "tensorflow/compiler/mlir/xla/ir/hlo_utils.td" @@ -95,6 +95,7 @@ def HLO_CreateTokenOp : HLO_Op<"create_token", [NoSideEffect]> { // XLA unary elementwise op definitions. //===----------------------------------------------------------------------===// // See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions + class HLO_UnaryElementwiseOp traits, Type TensorType>: HLO_Op { @@ -103,8 +104,7 @@ class HLO_UnaryElementwiseOp traits, let extraClassDeclaration = [{ static LogicalResult inferReturnTypeComponents( MLIRContext* context, Optional location, - ValueRange operands, ArrayRef attributes, - RegionRange regions, + ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferedReturnShapes) { return failure(); } @@ -161,6 +161,16 @@ def HLO_Expm1Op: HLO_UnaryElementwiseOp<"exponential_minus_one", def HLO_FloorOp: HLO_UnaryElementwiseOp<"floor", [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_FloorOp; +def HLO_ImagOp: HLO_Op< + "imag", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_ImagOp { + let builders = [OpBuilder< + "OpBuilder &, OperationState &tblgen_state, Value val">]; + + let arguments = (ins HLO_ComplexTensor); + let results = (outs HLO_FpTensor); + let hasFolder = 1; +} + def HLO_IsFiniteOp: HLO_UnaryElementwiseOp<"is_finite", [NoSideEffect, SameOperandsAndResultShape], HLO_Tensor>, BASE_HLO_IsFiniteOp { @@ -188,6 +198,16 @@ def HLO_PopulationCountOp: HLO_UnaryElementwiseOp<"popcnt", [NoSideEffect, SameOperandsAndResultType], HLO_IntTensor>, BASE_HLO_PopulationCountOp; +def HLO_RealOp: HLO_Op< + "real", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_RealOp { + let builders = [OpBuilder< + "OpBuilder &, OperationState &tblgen_state, Value val">]; + + let arguments = (ins HLO_ComplexTensor); + let results = (outs HLO_FpTensor); + let hasFolder = 1; +} + def HLO_RoundOp: HLO_UnaryElementwiseOp<"round_nearest_afz", [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_RoundOp; @@ -209,50 +229,14 @@ def HLO_SqrtOp: HLO_UnaryElementwiseOp<"sqrt", BASE_HLO_SqrtOp; def HLO_TanhOp: 
HLO_UnaryElementwiseOp<"tanh", - [ResultsAreFloatLike, NoSideEffect, SameOperandsAndResultType], + [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, BASE_HLO_TanhOp; -//===----------------------------------------------------------------------===// -// XLA complex unary elementwise op definitions. -//===----------------------------------------------------------------------===// -// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions - -def HLO_ComplexOp: HLO_Op<"complex", - [NoSideEffect, SameOperandsElementType, SameOperandsAndResultShape]>, - BASE_HLO_ComplexOp { - let builders = [OpBuilder< - "OpBuilder &, OperationState &tblgen_state, Value lhs, Value rhs">]; - - let arguments = (ins HLO_FpTensor:$lhs, HLO_FpTensor:$rhs); - let results = (outs HLO_ComplexTensor); - let hasFolder = 1; -} - -def HLO_ImagOp: HLO_Op< - "imag", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_ImagOp { - let builders = [OpBuilder< - "OpBuilder &, OperationState &tblgen_state, Value val">]; - - let arguments = (ins HLO_ComplexTensor); - let results = (outs HLO_FpTensor); - let hasFolder = 1; -} - -def HLO_RealOp: HLO_Op< - "real", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_RealOp { - let builders = [OpBuilder< - "OpBuilder &, OperationState &tblgen_state, Value val">]; - - let arguments = (ins HLO_ComplexTensor); - let results = (outs HLO_FpTensor); - let hasFolder = 1; -} - //===----------------------------------------------------------------------===// // XLA binary elementwise op definitions. //===----------------------------------------------------------------------===// - // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations + class HLO_BinaryElementwiseOp traits> : HLO_Op { let arguments = (ins @@ -269,7 +253,7 @@ class HLO_BinaryElementwiseOp traits> : let extraClassDeclaration = [{ static LogicalResult inferReturnTypeComponents( MLIRContext* context, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferedReturnShapes) { return failure(); } @@ -286,22 +270,40 @@ class HLO_BinaryElementwiseOp traits> : } def HLO_AddOp : HLO_BinaryElementwiseOp<"add", - [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_AddOp; + [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_AddOp { + let hasFolder = 1; +} def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2", [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_Atan2Op; +def HLO_ComplexOp: HLO_Op<"complex", + [NoSideEffect, SameOperandsElementType, SameOperandsAndResultShape]>, + BASE_HLO_ComplexOp { + let builders = [OpBuilder< + "OpBuilder &, OperationState &tblgen_state, Value lhs, Value rhs">]; + + let arguments = (ins HLO_FpTensor:$lhs, HLO_FpTensor:$rhs); + let results = (outs HLO_ComplexTensor); + let hasFolder = 1; +} + def HLO_DivOp : HLO_BinaryElementwiseOp<"divide", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_DivOp; + [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_DivOp { +} def HLO_MaxOp : HLO_BinaryElementwiseOp<"maximum", - [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MaxOp; + [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MaxOp { +} def HLO_MinOp : HLO_BinaryElementwiseOp<"minimum", - [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MinOp; + [Commutative, NoSideEffect, 
SameOperandsAndResultElementType]>, BASE_HLO_MinOp { +} def HLO_MulOp : HLO_BinaryElementwiseOp<"multiply", - [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MulOp; + [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MulOp { + let hasFolder = 1; +} def HLO_PowOp : HLO_BinaryElementwiseOp<"power", [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_PowOp; @@ -319,7 +321,9 @@ def HLO_ShiftRightLogicalOp : HLO_BinaryElementwiseOp<"shift_right_logical", [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftRightLogicalOp; def HLO_SubOp : HLO_BinaryElementwiseOp<"subtract", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_SubOp; + [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_SubOp { + let hasFolder = 1; +} //===----------------------------------------------------------------------===// // XLA binary elementwise op definitions. diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td index c087ffd1f40..b5de675f13f 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td @@ -150,15 +150,6 @@ class BASE_HLO_ClzOp { }]; } -class BASE_HLO_ComplexOp { - string summary = "Complex operator"; - - string description = [{ - Performs element-wise conversion of a pair of real and imaginary values to - a complex value. - }]; -} - class BASE_HLO_ConvertOp { string summary = "Convert operator"; @@ -400,6 +391,15 @@ class BASE_HLO_AddOp { }]; } +class BASE_HLO_ComplexOp { + string summary = "Complex operator"; + + string description = [{ + Performs element-wise conversion of a pair of real and imaginary values to + a complex value. + }]; +} + class BASE_HLO_DivOp { string summary = "Division operator"; diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td index 6fcb2582002..db75bbd1f67 100644 --- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td @@ -19,7 +19,7 @@ limitations under the License. #define LHLO_OPS include "mlir/IR/OpBase.td" -include "mlir/Interfaces/SideEffects.td" +include "mlir/Interfaces/SideEffectInterfaces.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td" def LHLO_Dialect : Dialect { @@ -92,39 +92,30 @@ def LHLO_CosOp: LHLO_UnaryElementwiseOp<"cosine">, BASE_HLO_CosOp; def LHLO_ExpOp: LHLO_UnaryElementwiseOp<"exponential">, BASE_HLO_ExpOp; +def LHLO_ImagOp: LHLO_Op<"imag", [SameOperandsShape]>, BASE_HLO_ImagOp { + let arguments = (ins Arg:$input, + Arg:$output); +} + def LHLO_LogOp: LHLO_UnaryElementwiseOp<"log">, BASE_HLO_LogOp; def LHLO_NegOp: LHLO_UnaryElementwiseOp<"negate">, BASE_HLO_NegOp; +def LHLO_RealOp: LHLO_Op<"real", [SameOperandsShape]>, BASE_HLO_RealOp { + let arguments = (ins Arg:$input, + Arg:$output); +} + def LHLO_RsqrtOp: LHLO_UnaryElementwiseOp<"rsqrt">, BASE_HLO_RsqrtOp; def LHLO_SqrtOp: LHLO_UnaryElementwiseOp<"sqrt">, BASE_HLO_SqrtOp; def LHLO_SignOp: LHLO_UnaryElementwiseOp<"sign">, BASE_HLO_SignOp; +def LHLO_SinOp: LHLO_UnaryElementwiseOp<"sine">, BASE_HLO_SinOp; + def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh">, BASE_HLO_TanhOp; -//===----------------------------------------------------------------------===// -// XLA complex unary elementwise op definitions. 
-//===----------------------------------------------------------------------===// -// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions - -def LHLO_ComplexOp: LHLO_Op<"complex", [SameOperandsShape]>, BASE_HLO_ComplexOp { - let arguments = (ins Arg:$lhs, - Arg:$rhs, - Arg:$output); -} - -def LHLO_ImagOp: LHLO_Op<"imag", [SameOperandsShape]>, BASE_HLO_ImagOp { - let arguments = (ins Arg:$input, - Arg:$output); -} - -def LHLO_RealOp: LHLO_Op<"real", [SameOperandsShape]>, BASE_HLO_RealOp { - let arguments = (ins Arg:$input, - Arg:$output); -} - //===----------------------------------------------------------------------===// // XLA binary elementwise op definitions. //===----------------------------------------------------------------------===// @@ -142,6 +133,12 @@ class LHLO_BinaryElementwiseOp traits> : def LHLO_AddOp : LHLO_BinaryElementwiseOp<"add", []>, BASE_HLO_AddOp; +def LHLO_ComplexOp: LHLO_Op<"complex", [SameOperandsShape]>, BASE_HLO_ComplexOp { + let arguments = (ins Arg:$lhs, + Arg:$rhs, + Arg:$output); +} + def LHLO_DivOp : LHLO_BinaryElementwiseOp<"divide", []>, BASE_HLO_DivOp; def LHLO_MaxOp : LHLO_BinaryElementwiseOp<"maximum", []>, BASE_HLO_MaxOp; diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc index cfa8c1b6bfc..461c357e509 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/xla/type_to_shape.h" #include "tensorflow/compiler/xla/comparison_util.h" #include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/util.h" namespace xla { @@ -55,6 +56,20 @@ static mlir::DenseIntElementsAttr GetI64ElementsAttr( return mlir::DenseIntElementsAttr::get(ty, mlir_values); } +static mlir::DenseIntElementsAttr ConvertPadding( + absl::Span> padding, + mlir::Builder* builder) { + llvm::SmallVector elements; + elements.reserve(padding.size() * 2); + for (const auto& vals : padding) { + elements.push_back(vals.first); + elements.push_back(vals.second); + } + auto ty = mlir::RankedTensorType::get( + {static_cast(padding.size()), 2}, builder->getIntegerType(64)); + return mlir::DenseIntElementsAttr::get(ty, elements); +} + MlirHloBuilder::~MlirHloBuilder() = default; StatusOr MlirHloBuilder::MakeXlaOp(mlir::Value val) { @@ -78,6 +93,31 @@ XlaOp MlirHloBuilder::ConstantLiteral(const LiteralSlice& literal) { }); } +StatusOr MlirHloBuilder::ConvGeneralDilatedInternal( + const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config) { + TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( + shape, builder_)); + mlir::ArrayAttr config_attr; + if (precision_config) + config_attr = ConvertPrecisionConfig(precision_config, &builder_); + auto op = builder_.create( + loc_, ty, GetValue(lhs), GetValue(rhs), + GetI64ElementsAttr(window_strides, &builder_), + ConvertPadding(padding, &builder_), + GetI64ElementsAttr(lhs_dilation, &builder_), + GetI64ElementsAttr(rhs_dilation, &builder_), + ConvertConvDimensionNumbers(dimension_numbers, &builder_), + builder_.getI64IntegerAttr(feature_group_count), + builder_.getI64IntegerAttr(batch_group_count), config_attr); + 
return MakeXlaOp(op); +} + StatusOr MlirHloBuilder::TransposeInternal( const Shape& shape, XlaOp operand, absl::Span permutation) { TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( @@ -100,6 +140,29 @@ StatusOr MlirHloBuilder::GatherInternal( return MakeXlaOp(op); } +StatusOr MlirHloBuilder::RngOpInternal( + RandomDistribution distribution, absl::Span parameters, + const Shape& shape) { + // TODO(hinsu): Introduce RngOp in the HLO dialect in MLIR and then RngUniform + // and RngNormal can be mapped to the new op. + std::string op_name; + if (distribution == xla::RandomDistribution::RNG_UNIFORM) { + op_name = "xla_hlo.rng_uniform"; + } else { + TF_RET_CHECK(distribution == xla::RandomDistribution::RNG_NORMAL) + << "Unexpected distribution: " << distribution; + op_name = "xla_hlo.rng_normal"; + } + + if (shape.is_dynamic()) + return Unimplemented("RngOp with dynamic dims not supported"); + llvm::SmallVector operands; + operands.append(parameters.begin(), parameters.end()); + operands.push_back( + ConstantLiteral(LiteralUtil::CreateR1(shape.dimensions()))); + return CreateOp(op_name, shape, operands); +} + StatusOr MlirHloBuilder::ReshapeInternal(const Shape& shape, XlaOp operand, int64 inferred_dimension) { @@ -154,15 +217,14 @@ StatusOr MlirHloBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp MlirHloBuilder::BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape, XlaOp lhs, XlaOp rhs) { return ReportErrorOrReturn([&]() -> StatusOr { - return CreateOp(GetMlirOpName(binop), shape, {lhs, rhs}, /*attributes=*/{}); + return CreateOp(GetMlirOpName(binop), shape, {lhs, rhs}); }); } StatusOr MlirHloBuilder::AddOpWithShape( HloOpcode opcode, const Shape& shape, absl::Span operands) { return CreateOp(GetMlirOpName(opcode), shape, - llvm::makeArrayRef(operands.data(), operands.size()), - /*attributes=*/{}); + llvm::makeArrayRef(operands.data(), operands.size())); } XlaOp MlirHloBuilder::CreateToken() { @@ -220,6 +282,28 @@ StatusOr MlirHloBuilder::SliceInternal( GetI64ElementsAttr(strides, &builder_))); } +StatusOr MlirHloBuilder::DynamicSliceInternal( + const Shape& shape, XlaOp operand, absl::Span start_indices, + absl::Span slice_sizes) { + TF_ASSIGN_OR_RETURN( + mlir::Type result_ty, + ConvertShapeToType(shape, builder_)); + return MakeXlaOp(builder_.create( + loc_, result_ty, GetValue(operand), GetValues(start_indices), + GetI64ElementsAttr(slice_sizes, &builder_))); +} + +StatusOr MlirHloBuilder::DynamicUpdateSliceInternal( + const Shape& shape, XlaOp operand, XlaOp update, + absl::Span start_indices) { + TF_ASSIGN_OR_RETURN( + mlir::Type result_ty, + ConvertShapeToType(shape, builder_)); + return MakeXlaOp(builder_.create( + loc_, result_ty, GetValue(operand), GetValue(update), + GetValues(start_indices))); +} + StatusOr MlirHloBuilder::PadInternal( const Shape& shape, XlaOp operand, XlaOp padding_value, const PaddingConfig& padding_config) { diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h index c0ef645a731..fc5baaee44d 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h @@ -101,9 +101,25 @@ class MlirHloBuilder : public XlaBuilder { // Returns the shape of the given op. StatusOr GetShapePtr(XlaOp op) const override; + // Creates the given op at the current location. + template + OpTy create(Args&&... 
args) { + return builder_.create(loc_, std::forward(args)...); + } + private: XlaOp ConstantLiteral(const LiteralSlice& literal) override; + StatusOr ConvGeneralDilatedInternal( + const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config) override; + StatusOr TransposeInternal( const Shape& shape, XlaOp operand, absl::Span permutation) override; @@ -113,6 +129,10 @@ class MlirHloBuilder : public XlaBuilder { const GatherDimensionNumbers& dimension_numbers, absl::Span slice_sizes, bool indices_are_sorted) override; + StatusOr RngOpInternal(RandomDistribution distribution, + absl::Span parameters, + const Shape& shape) override; + StatusOr ReshapeInternal(const Shape& shape, XlaOp operand, int64 inferred_dimension) override; @@ -155,6 +175,14 @@ class MlirHloBuilder : public XlaBuilder { absl::Span limit_indices, absl::Span strides) override; + StatusOr DynamicSliceInternal( + const Shape& shape, XlaOp operand, absl::Span start_indices, + absl::Span slice_sizes) override; + + StatusOr DynamicUpdateSliceInternal( + const Shape& shape, XlaOp operand, XlaOp update, + absl::Span start_indices) override; + StatusOr PadInternal(const Shape& shape, XlaOp operand, XlaOp padding_value, const PaddingConfig& padding_config) override; @@ -163,9 +191,10 @@ class MlirHloBuilder : public XlaBuilder { absl::Span elements) override; // Creates HLO dialect op and returns the result as an XlaOp. - StatusOr CreateOp(const std::string& op_name, const Shape& shape, - llvm::ArrayRef operands, - llvm::ArrayRef attributes); + StatusOr CreateOp( + const std::string& op_name, const Shape& shape, + llvm::ArrayRef operands, + llvm::ArrayRef attributes = {}); mlir::OpBuilder builder_; mlir::Location loc_; diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index d92e3d25343..228a26b5abd 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -56,6 +56,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/stream_executor/lib/statusor.h" using ::stream_executor::port::StatusOr; @@ -907,6 +908,10 @@ namespace mlir { namespace { StatusOr CreateLiteralFromAttr(ElementsAttr attr) { + if (attr.isa()) + return tensorflow::errors::Unimplemented( + "Opaque elements attr not supported"); + xla::Shape shape = xla::TypeToShape(attr.getType()); #define ELEMENTS_ATTR_TO_LITERAL(xla_type, cpp_type) \ @@ -928,6 +933,8 @@ StatusOr CreateLiteralFromAttr(ElementsAttr attr) { ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U16, uint16) ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U32, uint32) ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U64, uint64) + ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::C64, std::complex) + ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::C128, std::complex) case xla::PrimitiveType::F16: { llvm::SmallVector values; values.reserve(attr.getNumElements()); @@ -979,10 +986,26 @@ LogicalResult ConvertToHloModule::Lower( return LowerFunctionCall(&call_op, builder, &value_map); } + if (auto op = dyn_cast(inst)) { + Value operand = op.getOperand(); + auto ty = operand.getType().dyn_cast(); + // If this was a cast from a static shaped tensors, then it is a noop for + // export to HLO and we can use the operand. + if (!ty || !ty.hasStaticShape()) { + inst->emitOpError() + << "requires static shaped operand for HLO translation"; + return failure(); + } + + value_map[op.getResult()] = value_map[operand]; + return success(); + } + // TODO(jpienaar): This doesn't support layouts yet. if (matchPattern(inst, m_Constant(&const_attr))) { auto literal_or = CreateLiteralFromAttr(const_attr); - if (!literal_or.ok()) return inst->emitError("unsupported elemental type"); + if (!literal_or.ok()) + return inst->emitError(literal_or.status().ToString()); value_map[inst->getResult(0)] = xla::ConstantLiteral(builder, literal_or.ValueOrDie()); return success(); diff --git a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir index 5f28693c49d..30255586002 100644 --- a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir @@ -1,5 +1,50 @@ // RUN: xla-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s --dump-input-on-failure +// CHECK-LABEL: add_fold +func @add_fold() -> tensor<4xi64> { + %0 = xla_hlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64> + %1 = xla_hlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64> + // CHECK: xla_hlo.constant dense<[6, 8, 10, 12]> + %2 = "xla_hlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + return %2 : tensor<4xi64> +} + +// CHECK-LABEL: add_scalar_fold +func @add_scalar_fold() -> tensor<4xi64> { + %0 = xla_hlo.constant dense<1> : tensor<4xi64> + %1 = xla_hlo.constant dense<5> : tensor<4xi64> + // CHECK: xla_hlo.constant dense<6> + %2 = "xla_hlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + return %2 : tensor<4xi64> +} + +// CHECK-LABEL: add_fold_float +func @add_fold_float() -> tensor<4xf64> { + %0 = xla_hlo.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf64> + %1 = xla_hlo.constant dense<[5.0, 6.0, 7.0, 8.0]> : tensor<4xf64> + // CHECK: xla_hlo.constant dense<[6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01]> + %2 = "xla_hlo.add"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>) + return %2 : tensor<4xf64> 
+} + +// CHECK-LABEL: sub_scalar_fold +func @sub_scalar_fold() -> tensor<4xi64> { + %0 = xla_hlo.constant dense<5> : tensor<4xi64> + %1 = xla_hlo.constant dense<1> : tensor<4xi64> + // CHECK: xla_hlo.constant dense<4> + %2 = "xla_hlo.subtract"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + return %2 : tensor<4xi64> +} + +// CHECK-LABEL: multiply_scalar_fold +func @multiply_scalar_fold() -> tensor<4xi64> { + %0 = xla_hlo.constant dense<5> : tensor<4xi64> + %1 = xla_hlo.constant dense<3> : tensor<4xi64> + // CHECK: xla_hlo.constant dense<15> + %2 = "xla_hlo.multiply"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + return %2 : tensor<4xi64> +} + // CHECK-LABEL: concatenate_noop func @concatenate_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> { // CHECK-SAME: [[ARG:%.+]]: tensor<4xi32> @@ -43,6 +88,54 @@ func @concatenate_empty_float(%arg0: tensor<0xf32>, %arg1: tensor<0xf32>) -> ten return %0 : tensor<0xf32> } +// CHECK-LABEL: concatenate_const_1D +func @concatenate_const_1D() -> tensor<4xi32> { + // CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[0, 1, 2, 3]> + %0 = xla_hlo.constant dense<[0, 1]> : tensor<2xi32> + %1 = xla_hlo.constant dense<[2, 3]> : tensor<2xi32> + %2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xi32>, tensor<2xi32>) -> tensor<4xi32> + + // CHECK: return [[VAL]] + return %2 : tensor<4xi32> +} + +// CHECK-LABEL: concatenate_const_1D_float +func @concatenate_const_1D_float() -> tensor<4xf32> { + // CHECK: [[VAL:%.+]] = xla_hlo.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> + + %0 = xla_hlo.constant dense<[0.0, 1.0]> : tensor<2xf32> + %1 = xla_hlo.constant dense<[2.0, 3.0]> : tensor<2xf32> + %2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xf32>, tensor<2xf32>) -> tensor<4xf32> + + // CHECK: return [[VAL]] + return %2 : tensor<4xf32> +} + +// CHECK-LABEL: concatenate_const_2D_vertical +func @concatenate_const_2D_vertical() -> tensor<2x2xi32> { + // CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[ + // CHECK-SAME: [0, 1], [2, 3] + // CHECK-SAME: ]> + %0 = xla_hlo.constant dense<[[0, 1]]> : tensor<1x2xi32> + %1 = xla_hlo.constant dense<[[2, 3]]> : tensor<1x2xi32> + %2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32> + + // CHECK: return [[VAL]] + return %2 : tensor<2x2xi32> +} + +// CHECK-LABEL: concatenate_const_2D_horizontal +func @concatenate_const_2D_horizontal() -> tensor<2x2xi32> { + // CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[ + // CHECK-SAME: [0, 2], [1, 3] + // CHECK-SAME: ]> + %0 = xla_hlo.constant dense<[[0], [1]]> : tensor<2x1xi32> + %1 = xla_hlo.constant dense<[[2], [3]]> : tensor<2x1xi32> + %2 = "xla_hlo.concatenate"(%0, %1) { dimension = 1 : i64 } : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32> + + // CHECK: return [[VAL]] + return %2 : tensor<2x2xi32> +} // CHECK-LABEL: dynamic_slice_variable_start func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { diff --git a/tensorflow/compiler/mlir/xla/tests/chlo_infer_shape_type_methods.mlir b/tensorflow/compiler/mlir/xla/tests/chlo_infer_shape_type_methods.mlir index ce0243e416c..d67a7d09f7c 100644 --- a/tensorflow/compiler/mlir/xla/tests/chlo_infer_shape_type_methods.mlir +++ b/tensorflow/compiler/mlir/xla/tests/chlo_infer_shape_type_methods.mlir @@ -6,8 +6,8 @@ // CHECK-SAME: %[[ARG0:.+]]: tensor, // CHECK-SAME: %[[ARG1:.+]]: tensor func @broadcast_add(%arg0: tensor, %arg1: tensor) -> tensor<1xindex> { - 
// CHECK-DAG: %[[ARG0_S:.+]] = "shape.shape_of"(%[[ARG0]]) - // CHECK-DAG: %[[ARG1_S:.+]] = "shape.shape_of"(%[[ARG1]]) + // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] + // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] // CHECK-DAG: %[[BCAST_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) // CHECK: %[[EXTENTS:.+]] = "shape.to_extent_tensor"(%[[BCAST_S]]) // CHECK: return %[[EXTENTS]] diff --git a/tensorflow/compiler/mlir/xla/tests/chlo_legalize_to_hlo_broadcasts.mlir b/tensorflow/compiler/mlir/xla/tests/chlo_legalize_to_hlo_broadcasts.mlir index 2bc1e0c6852..7194f7034b5 100644 --- a/tensorflow/compiler/mlir/xla/tests/chlo_legalize_to_hlo_broadcasts.mlir +++ b/tensorflow/compiler/mlir/xla/tests/chlo_legalize_to_hlo_broadcasts.mlir @@ -14,8 +14,8 @@ func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor< // CHECK-SAME: %[[ARG0:.+]]: tensor // CHECK-SAME: %[[ARG1:.+]]: tensor func @dynamicBroadcast(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-DAG: %[[ARG0_S:.+]] = "shape.shape_of"(%[[ARG0]]) - // CHECK-DAG: %[[ARG1_S:.+]] = "shape.shape_of"(%[[ARG1]]) + // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] + // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] // CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) // CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_S]]) // CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} @@ -31,8 +31,8 @@ func @dynamicBroadcast(%arg0: tensor, %arg1: tensor) -> tensor // CHECK-SAME: %[[ARG1:.+]]: tensor func @dynamicBroadcastComplex(%arg0: tensor, %arg1: tensor) -> tensor> { - // CHECK-DAG: %[[ARG0_S:.+]] = "shape.shape_of"(%[[ARG0]]) - // CHECK-DAG: %[[ARG1_S:.+]] = "shape.shape_of"(%[[ARG1]]) + // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] + // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] // CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) // CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_S]]) // CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor @@ -48,8 +48,8 @@ func @dynamicBroadcastComplex(%arg0: tensor, %arg1: tensor) -> t // CHECK-SAME: %[[ARG0:.+]]: tensor // CHECK-SAME: %[[ARG1:.+]]: tensor func @dynamicBroadcastCompare(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-DAG: %[[ARG0_S:.+]] = "shape.shape_of"(%[[ARG0]]) - // CHECK-DAG: %[[ARG1_S:.+]] = "shape.shape_of"(%[[ARG1]]) + // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] + // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] // CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) // CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_S]]) // CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir index 262533bbf08..53296b257ae 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir @@ -1,4 +1,4 @@ -// RUN: xla-opt -hlo-legalize-to-lhlo %s -o - | FileCheck %s --dump-input-on-failure +// RUN: xla-opt -hlo-legalize-to-lhlo 
-buffer-placement %s -o - | FileCheck %s --dump-input-on-failure // CHECK-LABEL: func @attrs func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { @@ -13,33 +13,42 @@ func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // ----- +func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> { + return %arg0 : tensor<4xf32> +} +// CHECK: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]]) +// CHECK-NEXT: "xla_lhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK-NEXT: "xla_lhlo.terminator"() : () -> () + +// ----- + // CHECK-LABEL: func @func_op_long func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>) - // CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() {temp = true} : memref<4xf32> - // CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() {temp = true} : memref<4xf32> - // CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() {temp = true} : memref<4xf32> - // CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() {temp = true} : memref<4xf32> - // CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() {temp = true} : memref<4xf32> %1 = xla_hlo.maximum %arg0, %arg1 : tensor<4xf32> - // CHECK-NEXT: "xla_lhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]]) %2 = xla_hlo.add %arg0, %1 : tensor<4xf32> - // CHECK-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]]) %3 = xla_hlo.minimum %arg0, %arg1 : tensor<4xf32> - // CHECK-NEXT: "xla_lhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]]) %4 = xla_hlo.subtract %arg1, %3 : tensor<4xf32> - // CHECK-NEXT: "xla_lhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]]) %5 = xla_hlo.multiply %2, %4 : tensor<4xf32> - // CHECK-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]]) - // CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32> - // CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32> - // CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32> - // CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32> - // CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> () - // CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32> return %5 : tensor<4xf32> - // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () } +// CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>) +// CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32> +// CHECK-NEXT: "xla_lhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]]) +// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32> +// CHECK-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]]) +// CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32> +// CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32> +// CHECK-NEXT: "xla_lhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]]) +// CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32> +// CHECK-NEXT: "xla_lhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]]) +// CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32> +// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32> +// CHECK-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]]) +// CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32> +// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32> +// CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> () +// CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32> +// CHECK-NEXT: 
"xla_lhlo.terminator"() : () -> () // ----- @@ -47,20 +56,20 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>, %summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}}, %[[RESULT:.*]]: {{.*}}) - // CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() {temp = true} : memref<2x2xf32> - // CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() {temp = true} : memref<2x2xf32> + // CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32> %tensor_summand_1 = tensor_load %summand_1 : memref<2x2xf32> %tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32> %sum = "xla_hlo.add"(%tensor_summand_1, %tensor_summand_2) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // CHECK-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]]) + // CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32> %tensor_multiplier = tensor_load %multiplier : memref<2x2xf32> %tensor_result = "xla_hlo.multiply"(%sum, %tensor_multiplier) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // CHECK-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]]) + // CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32> // CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) tensor_store %tensor_result, %result : memref<2x2xf32> - // CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32> // CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32> // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () "xla_lhlo.terminator"() : () -> () diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir index ca8e64b9141..a856ee5e83c 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir @@ -222,6 +222,16 @@ func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- +// CHECK-LABEL: func @float_sin +func @float_sin(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: linalg.generic + // CHECK: sin + %0 = "xla_hlo.sine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + // CHECK-LABEL: func @copy // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> { @@ -274,8 +284,8 @@ func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> { // CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, 0)> // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> -// CHECK-LABEL: func @broadcast -func @broadcast(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> { +// CHECK-LABEL: func @broadcast_in_dim +func @broadcast_in_dim(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> { %0 = "xla_hlo.broadcast_in_dim"(%operand) {broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64>} : (tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> @@ -287,6 +297,22 @@ func @broadcast(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> { // ----- +// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1) -> (d0)> +// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @broadcast_in_dim_with_one_to_one +func @broadcast_in_dim_with_one_to_one( + %operand: tensor<1xf32>) -> tensor<1x5xf32> { + %0 = "xla_hlo.broadcast_in_dim"(%operand) + {broadcast_dimensions = dense<[0]> : tensor<1xi64>} + : (tensor<1xf32>) -> tensor<1x5xf32> + return %0 : 
tensor<1x5xf32> +} +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] +// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 + +// ----- + // CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2) -> ()> // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-LABEL: func @broadcast_scalar @@ -444,3 +470,75 @@ func @reshape_multiple_collapse // CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> // CHECK-LABEL: func @reshape_multiple_collapse // CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]] + +// ----- + +// CHECK-LABEL: func @convert_i32_to_f32 +func @convert_i32_to_f32(%input: tensor<2x2xi32>) -> tensor<2x2xf32> { + %result = "xla_hlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32> + return %result : tensor<2x2xf32> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32): +// CHECK-NEXT: %[[RESULT:.*]] = sitofp %[[OPERAND_IN]] : i32 to f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + +// CHECK-LABEL: func @convert_i16_to_i32 +func @convert_i16_to_i32(%input: tensor<2x2xi16>) -> tensor<2x2xi32> { + %result = "xla_hlo.convert"(%input) : (tensor<2x2xi16>) -> tensor<2x2xi32> + return %result : tensor<2x2xi32> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16): +// CHECK-NEXT: %[[RESULT:.*]] = zexti %[[OPERAND_IN]] : i16 to i32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 + +// ----- + +// CHECK-LABEL: func @convert_i32_to_i16 +func @convert_i32_to_i16(%input: tensor<2x2xi32>) -> tensor<2x2xi16> { + %result = "xla_hlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xi16> + return %result : tensor<2x2xi16> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32): +// CHECK-NEXT: %[[RESULT:.*]] = trunci %[[OPERAND_IN]] : i32 to i16 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i16 + +// ----- + +// CHECK-LABEL: func @convert_f32_to_f64 +func @convert_f32_to_f64(%input: tensor<2x2xf32>) -> tensor<2x2xf64> { + %result = "xla_hlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xf64> + return %result : tensor<2x2xf64> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32): +// CHECK-NEXT: %[[RESULT:.*]] = fpext %[[OPERAND_IN]] : f32 to f64 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f64 + +// ----- + +// CHECK-LABEL: func @convert_f64_to_f32 +func @convert_f64_to_f32(%input: tensor<2x2xf64>) -> tensor<2x2xf32> { + %result = "xla_hlo.convert"(%input) : (tensor<2x2xf64>) -> tensor<2x2xf32> + return %result : tensor<2x2xf32> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f64): +// CHECK-NEXT: %[[RESULT:.*]] = fptrunc %[[OPERAND_IN]] : f64 to f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + +// CHECK-LABEL: func @convert_f32_to_i32 +func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> { + %result = "xla_hlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi32> + return %result : tensor<2x2xi32> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32): +// CHECK-NEXT: %[[RESULT:.*]] = fptosi %[[OPERAND_IN]] : f32 to i32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir new file mode 100644 index 00000000000..149c0c94663 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir @@ -0,0 +1,307 @@ +// 
RUN: xla-opt -split-input-file -xla-hlo-to-lhlo-with-xla %s | FileCheck --enable-var-scope --dump-input=fail %s + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> +func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> +// CHECK: lhlo.abs +// CHECK-SAME: %[[ARG0]], %[[VIEW]] + %abs = "xla_hlo.abs"(%value) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %abs : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> +func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> +// CHECK: lhlo.add +// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] +// CHECK-NEXT: return + %res = "xla_hlo.add"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %res : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> +func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32> +// CHECK: lhlo.and +// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] +// CHECK-NEXT: return + %res = "xla_hlo.and"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + return %res : tensor<2x2xi32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> +func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> +// CHECK: lhlo.ceil +// CHECK-SAME: %[[ARG0]], %[[VIEW]] + %res = "xla_hlo.ceil"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %res : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<1x2xf32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> +func @main(%value0: tensor<1x2xf32>, %value1: tensor<1x2xf32>) -> tensor<1x2xcomplex> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<1x2xcomplex> +// CHECK: lhlo.complex +// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] +// CHECK-NEXT: return + %res = "xla_hlo.complex"(%value0, %value1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex>) + return %res : tensor<1x2xcomplex> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> +func @main(%value0: tensor<1x2xcomplex>) -> tensor<1x2xcomplex> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<1x2xcomplex> +// CHECK: lhlo.cosine +// CHECK-SAME: %[[ARG0]], %[[VIEW]] +// CHECK-NEXT: return + %res = "xla_hlo.cosine"(%value0) : (tensor<1x2xcomplex>) -> tensor<1x2xcomplex> + return %res : tensor<1x2xcomplex> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> +func @main(%value0: tensor<2x2xf32>, %value1: 
tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> +// CHECK: lhlo.divide +// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] +// CHECK-NEXT: return + %res = "xla_hlo.divide"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %res : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> +func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> +// CHECK: lhlo.exponential +// CHECK-SAME: %[[ARG0]], %[[VIEW]] + %res = "xla_hlo.exponential"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %res : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> +func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> +// CHECK: lhlo.log +// CHECK-SAME: %[[ARG0]], %[[VIEW]] + %res = "xla_hlo.log"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %res : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> +func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> +// CHECK: lhlo.maximum +// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] +// CHECK-NEXT: return + %res = "xla_hlo.maximum"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %res : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> +func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> +// CHECK: lhlo.minimum +// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] +// CHECK-NEXT: return + %res = "xla_hlo.minimum"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %res : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> +func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> +// CHECK: lhlo.multiply +// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] +// CHECK-NEXT: return + %res = "xla_hlo.multiply"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %res : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> +func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> +// CHECK: lhlo.negate +// CHECK-SAME: %[[ARG0]], %[[VIEW]] + %res = "xla_hlo.negate"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %res : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: 
%[[ARG0:.*]]: memref<1x2xcomplex> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<8xi8> +func @main(%value0: tensor<1x2xcomplex>) -> tensor<1x2xf32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<8xi8> to memref<1x2xf32> +// CHECK: lhlo.real +// CHECK-SAME: %[[ARG0]], %[[VIEW]] + %res = "xla_hlo.real"(%value0) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) + return %res : tensor<1x2xf32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<8xi8> +func @main(%value0: tensor<1x2xcomplex>) -> tensor<1x2xf32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<8xi8> to memref<1x2xf32> +// CHECK: lhlo.imag +// CHECK-SAME: %[[ARG0]], %[[VIEW]] + %res = "xla_hlo.imag"(%value0) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) + return %res : tensor<1x2xf32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> +func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32> +// CHECK: lhlo.remainder +// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] +// CHECK-NEXT: return + %res = "xla_hlo.remainder"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + return %res : tensor<2x2xi32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> +func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> +// CHECK: lhlo.rsqrt +// CHECK-SAME: %[[ARG0]], %[[VIEW]] + %res = "xla_hlo.rsqrt"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %res : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi1> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG2:.*]]: memref<2x2xf32> {xla_lhlo.params = 2 +// CHECK-SAME: %[[ARG3:.*]]: memref<16xi8> +func @main(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> +// CHECK: lhlo.select +// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[VIEW]] +// CHECK-NEXT: return + %0 = "xla_hlo.select"(%pred, %lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>) + return %0 : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> +func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> +// CHECK: lhlo.sign +// CHECK-SAME: %[[ARG0]], %[[VIEW]] + %res = "xla_hlo.sign"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %res : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> +func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> +// CHECK: lhlo.sqrt +// CHECK-SAME: %[[ARG0]], %[[VIEW]] + %res = "xla_hlo.sqrt"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %res : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @main +// 
CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> +func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32> +// CHECK: lhlo.subtract +// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] +// CHECK-NEXT: return + %res = "xla_hlo.subtract"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + return %res : tensor<2x2xi32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> +func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> +// CHECK: lhlo.tanh +// CHECK-SAME: %[[ARG0]], %[[VIEW]] + %res = "xla_hlo.tanh"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %res : tensor<2x2xf32> +} diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir index cda1dc481a7..6a2b68adac3 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir @@ -8,7 +8,9 @@ // CHECK-SAME: ) { func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf32> { // The only expected instruction is a copy from the input into the output. - // CHECK: %[[OUTPUT:.*]] = std.view %[[ARG1]][][] : memref<16xi8> to memref<2x2xf32> + // CHECK: %[[C0:.*]] = constant 0 : index + // CHECK: %[[C02:.*]] = constant 0 : index + // CHECK: %[[OUTPUT:.*]] = std.view %[[ARG1]][%[[C02]]][] : memref<16xi8> to memref<2x2xf32> // CHECK: xla_lhlo.copy // CHECK-SAME: %[[ARG0]], %[[OUTPUT]] return %value : tensor<2x2xf32> diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir index 08df9fd3808..3605e2a0d5c 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir @@ -7,8 +7,8 @@ func @batchmatmulv2_basic(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> { // CHECK-LABEL: func @batchmatmulv2_basic // CHECK-SAME: ([[LHS:%.*]]: tensor<1x4x2xf32>, [[RHS:%.*]]: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> -// CHECK: [[LHSSHAPE:%.*]] = "shape.shape_of"([[LHS]]) : (tensor<1x4x2xf32>) -> !shape.shape -// CHECK: [[RHSSHAPE:%.*]] = "shape.shape_of"([[RHS]]) : (tensor<3x2x4xf32>) -> !shape.shape +// CHECK: [[LHSSHAPE:%.*]] = shape.shape_of [[LHS]] : tensor<1x4x2xf32> +// CHECK: [[RHSSHAPE:%.*]] = shape.shape_of [[RHS]] : tensor<3x2x4xf32> // CHECK: [[CM2:%.*]] = constant -2 : i32 // CHECK: [[LHSHEAD:%.*]], [[LHSTAIL:%.*]] = "shape.split_at"([[LHSSHAPE]], [[CM2]]) : (!shape.shape, i32) -> (!shape.shape, !shape.shape) // CHECK: [[RHSHEAD:%.*]], [[RHSTAIL:%.*]] = "shape.split_at"([[RHSSHAPE]], [[CM2]]) : (!shape.shape, i32) -> (!shape.shape, !shape.shape) @@ -86,8 +86,8 @@ func @batchmatmulv2_adj_complex(%arg0: tensor<5x2xcomplex>, %arg1: tensor<2 // CHECK: [[RHSIM:%.*]] = "xla_hlo.imag"([[RHS]]) // CHECK: [[RHSIMNEG:%.*]] = "xla_hlo.negate"([[RHSIM]]) // CHECK: [[RHSCONJ:%.*]] = "xla_hlo.complex"([[RHSRE]], [[RHSIMNEG]]) -// CHECK: "shape.shape_of"([[LHSCONJ]]) -// CHECK: "shape.shape_of"([[RHSCONJ]]) +// CHECK: shape.shape_of [[LHSCONJ]] +// CHECK: shape.shape_of 
[[RHSCONJ]] %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xcomplex>, tensor<2x4xcomplex>) -> tensor<5x4xcomplex> return %0 : tensor<5x4xcomplex> } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-full-conversion.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-full-conversion.mlir index d2b4d269fef..0660af4ed1c 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-full-conversion.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-full-conversion.mlir @@ -1,22 +1,24 @@ // RUN: tf-opt %s -xla-legalize-tf -split-input-file -verify-diagnostics +// expected-error@below{{The following operations cannot be legalized: tf.NoOp (count: 1); tf_executor.fetch (count: 1); tf_executor.graph (count: 1); tf_executor.island (count: 1); tf_executor.yield (count: 1). These legalization failure(s) may be due to missing TF to HLO lowerings and/or unsupported attributes, etc.}} +// expected-error@below{{Emitting more detail about one op that failed to legalize...}} func @tf_executor_graph_op() { - // expected-error@+1 {{failed to legalize operation 'tf_executor.graph'}} tf_executor.graph { %0 = tf_executor.island { + // expected-error@+1 {{'tf.NoOp' op is not legalizable}} "tf.NoOp"() {} : () -> () tf_executor.yield } tf_executor.fetch } return - } // ----- +// expected-error@below{{The following operations cannot be legalized: tf.OpA (count: 1). These legalization failure(s) may be due to missing TF to HLO lowerings and/or unsupported attributes, etc.}} func @tf_unknown_op(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // expected-error@+1 {{failed to legalize operation 'tf.OpA'}} + // expected-error@+1 {{'tf.OpA' op is not legalizable}} %0 = "tf.OpA"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0: tensor<2xi32> } @@ -27,3 +29,16 @@ func @tf_known_op(%arg0: tensor<2xi32>) -> tensor<2xi32> { %0 = "tf.Add"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0: tensor<2xi32> } + +// ----- + +// expected-error@below{{The following operations cannot be legalized: tf.OpA (count: 1); tf.OpB (count: 2). 
These legalization failure(s) may be due to missing TF to HLO lowerings and/or unsupported attributes, etc.}} +// expected-error@below{{Emitting more detail about one op that failed to legalize...}} +func @tf_unknown_known_mix(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // expected-error@+1 {{'tf.OpA' op is not legalizable}} + %0 = "tf.OpA"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %1 = "tf.OpB"(%0, %0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %2 = "tf.Add"(%1, %1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %3 = "tf.OpB"(%2, %2) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + return %2: tensor<2xi32> +} diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir index d2ce1d311f6..e8d5cfe997d 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir @@ -50,6 +50,15 @@ func @dynamic_operand(%arg0: tensor) -> tensor { return %0 : tensor } +// CHECK-LABEL: unsupported_dtype +func @unsupported_dtype(%arg0: tensor<2x!tf.variant>) -> tensor<2x!tf.variant> { + // CHECK: tf.AddN + // expected-remark@+1 {{unsupported type: tensor<2x!tf.variant>}} + %0 = "tf.AddN"(%arg0, %arg0) : (tensor<2x!tf.variant>, tensor<2x!tf.variant>) -> tensor<2x!tf.variant> + + return %0 : tensor<2x!tf.variant> +} + // CHECK-LABEL: multiple_dialect_ops func @multiple_dialect_ops(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK: xla_hlo.negate @@ -115,12 +124,68 @@ func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> { return %0: tensor<2xi1> } -// TODO(hinsu): Add a test with variant type once one of the ops supporting -// the type is whitelisted. It should be rejected with unsupported type remark. +// CHECK-LABEL: func @const_inputs +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x2xf64>, %[[ARG1:.*]]: tensor, +func @const_inputs(%arg0: tensor<2x2xf64>, %arg1: tensor, %arg2: tensor<2xi32>, %arg3: tensor<2xi32>, %arg4: tensor<2xi32>) -> tensor<6x5xf64> { -// TODO(hinsu): Add a test with uint8 type once one of the ops supporting the -// type is whitelisted. Unsigned types are not yet added to the HLO dialect so -// it should return an error. 
See b/130356985 + // CHECK: "xla_hlo.pad"(%[[ARG0]], %[[ARG1]]) + // CHECK-SAME-DAG: edge_padding_high = dense<[1, 2]> : tensor<2xi64> + // CHECK-SAME-DAG: edge_padding_low = dense<[2, 1]> : tensor<2xi64> + // CHECK-SAME-DAG: interior_padding = dense<[1, 0]> : tensor<2xi64> + + %0 = xla_hlo.constant dense<[2, 1]> : tensor<2xi32> + %1 = xla_hlo.constant dense<[1, 2]> : tensor<2xi32> + %2 = xla_hlo.constant dense<[1, 0]> : tensor<2xi32> + %3 = "tf.XlaPad"(%arg0, %arg1, %0, %1, %2) : (tensor<2x2xf64>, tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<6x5xf64> + return %3 : tensor<6x5xf64> +} + +func @non_const_inputs(%arg0: tensor<2x2xf64>, %arg1: tensor, %arg2: tensor<2xi32>, %arg3: tensor<2xi32>, %arg4: tensor<2xi32>) -> tensor<6x5xf64> { + // expected-remark@+1 {{lowering requires operand #2 to be a constant}} + %0 = "tf.XlaPad"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<2x2xf64>, tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<6x5xf64> + return %0 : tensor<6x5xf64> +} + +// CHECK-LABEL: dynamic_result_type +func @dynamic_result_type(%arg0: tensor<2xf32>) -> tensor<*xf32> { + // CHECK: %[[RESULT:.*]] = "xla_hlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: tensor_cast %0 : tensor<2xf32> to tensor<*xf32> + %0 = "tf.Abs"(%arg0) : (tensor<2xf32>) -> tensor<*xf32> + + // return %[[RESULT]] + return %0 : tensor<*xf32> +} + +func @truncated_normal() -> tensor<2x2xf32> { + // CHECK-NOT: tf.TruncatedNormal + %0 = xla_hlo.constant dense<[2, 2]> : tensor<2xi32> + %1 = "tf.TruncatedNormal"(%0) {T = i32, device = "", dtype = f32, seed = 0 : i64, seed2 = 1950157571 : i64} : (tensor<2xi32>) -> tensor<2x2xf32> + return %1 : tensor<2x2xf32> +} + +// CHECK-LABEL: dynamic_update_slice +// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4xi32>, %[[ARG1:.*]]: tensor<2x2xi32>, %[[ARG2:.*]]: tensor<2xi32> +func @dynamic_update_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<2x2xi32>, %arg2: tensor<2xi32>) -> tensor<3x4xi32> { + + // CHECK: %[[SLICE0:.*]] = "xla_hlo.slice"(%[[ARG2]]) + // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64> + // CHECK-DAG-SAME: limit_indices = dense<1> : tensor<1xi64> + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64> + // CHECK-SAME: (tensor<2xi32>) -> tensor<1xi32> + // CHECK: %[[DIM0:.*]] = "xla_hlo.reshape"(%[[SLICE0]]) : (tensor<1xi32>) -> tensor + + // CHECK: %[[SLICE1:.*]] = "xla_hlo.slice"(%[[ARG2]]) + // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64> + // CHECK-DAG-SAME: limit_indices = dense<2> : tensor<1xi64> + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64> + // CHECK-SAME: (tensor<2xi32>) -> tensor<1xi32> + // CHECK: %[[DIM1:.*]] = "xla_hlo.reshape"(%[[SLICE1]]) : (tensor<1xi32>) -> tensor + + // CHECK: "xla_hlo.dynamic-update-slice"(%[[ARG0]], %[[ARG1]], %[[DIM0]], %[[DIM1]]) + + %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<3x4xi32>, tensor<2x2xi32>, tensor<2xi32>) -> tensor<3x4xi32> + return %0: tensor<3x4xi32> +} // TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is // available but doesn't support this instance. diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index e15101a165e..450910b2e4d 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -426,6 +426,8 @@ func @biasAdd_dynamic(%arg0: tensor, %arg1: tensor) -> tenso //===----------------------------------------------------------------------===// // Binary op legalizations. 
+// Most of these expand from the same pattern. Full semantics are +// verified for tf.Add and pattern application only for the rest. //===----------------------------------------------------------------------===// // CHECK-LABEL: func @add @@ -439,19 +441,49 @@ func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> { } // CHECK-LABEL: func @broadcast_add +// TODO(laurenzo): Change this to a (5 + 2x1) shaped add to make the check +// patterns unambiguous and more interesting (once broadcastable trait is +// fixed upstream). func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - // CHECK-NEXT: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK: %[[UNUSED_LHS_SHAPE:.+]] = shape.const_shape [1] + // CHECK: %[[UNUSED_RHS_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK: xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0: tensor<1x2xi32> } // CHECK-LABEL: func @broadcast_multi_dim_add +// TODO(laurenzo): Change this to a (4x1x1 + 1x4x4x4) shaped add once upstream +// broadcastable bug is fixed (helps make the CHECK matching unambiguous) func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { - // CHECK-NEXT: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} + // CHECK: %[[UNUSED_LHS_SHAPE:.+]] = shape.const_shape [4, 1, 1] + // CHECK: %[[UNUSED_RHS_SHAPE:.+]] = shape.const_shape [4, 4, 4, 4] + // CHECK: %[[RESULT_SHAPE:.+]] = shape.const_shape [4, 4, 4, 4] + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} + // CHECK: xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] %0 = "tf.Add"(%arg0, %arg1) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> return %0: tensor<4x4x4x4xi32> } +// CHECK-LABEL: func @add_dynamic +func @add_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1 + // CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]]) + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK: xla_hlo.add %4, %5 : tensor + %0 = "tf.Add"(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0: tensor +} + // CHECK-LABEL: func @div func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { // CHECK-NEXT: %0 = 
xla_hlo.divide %arg0, %arg0 : tensor<2xi32> @@ -460,13 +492,6 @@ func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { return %0: tensor<2xi32> } -// CHECK-LABEL: func @broadcast_div -func @broadcast_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - // CHECK-NEXT: "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} - %0 = "tf.Div"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> - return %0: tensor<1x2xi32> -} - // CHECK-LABEL: func @shift_left func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { // CHECK: xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32> @@ -474,13 +499,6 @@ func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { return %0 : tensor<4xi32> } -// CHECK-LABEL: func @div_dynamic -func @div_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} - %0 = "tf.Div"(%arg0, %arg1) : (tensor, tensor) -> tensor - return %0: tensor -} - // CHECK-LABEL: func @div_unranked func @div_unranked(%arg0: tensor<*xi32>, %arg1: tensor) -> tensor { // CHECK: tf.Div @@ -510,13 +528,6 @@ func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { return %0: tensor<2xi32> } -// CHECK-LABEL: func @broadcast_mul -func @broadcast_mul(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - // CHECK-NEXT: "xla_hlo.multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} - %0 = "tf.Mul"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> - return %0: tensor<1x2xi32> -} - // CHECK-LABEL: func @real_div func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> { // CHECK-NEXT: %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32> @@ -524,13 +535,6 @@ func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> { return %0: tensor<2xi32> } -// CHECK-LABEL: func @broadcast_real_div -func @broadcast_real_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - // CHECK-NEXT: "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} - %0 = "tf.RealDiv"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> - return %0: tensor<1x2xi32> -} - // CHECK-LABEL: func @sub func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> { // CHECK-NEXT: %0 = xla_hlo.subtract %arg0, %arg0 : tensor<2xi32> @@ -539,13 +543,6 @@ func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> { return %0: tensor<2xi32> } -// CHECK-LABEL: func @broadcast_sub -func @broadcast_sub(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - // CHECK-NEXT: "xla_hlo.subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} - %0 = "tf.Sub"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> - return %0: tensor<1x2xi32> -} - // CHECK-LABEL: func @shift_right func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { // CHECK: xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32> @@ -553,13 +550,6 @@ func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { return %0 : tensor<4xi32> } -// CHECK-LABEL: func @broadcast_shift_right -func @broadcast_shift_right(%arg0: tensor<4xi32>, %arg1: tensor<2x4xi32>) -> tensor<2x4xi32> { - // CHECK: "xla_hlo.shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} - %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> - return %0 : tensor<2x4xi32> -} - // CHECK-LABEL: func @shift_right_unsigned func 
@shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<4xui8>) -> tensor<4xui8> { // CHECK: tf.RightShift @@ -581,20 +571,6 @@ func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> { return %0: tensor<2xi1> } -// CHECK-LABEL: func @and_broadcast -func @and_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> { - // CHECK-NEXT: "xla_hlo.and" - %0 = "tf.LogicalAnd"(%arg0, %arg1) : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - -// CHECK-LABEL: func @and_dynamic -func @and_dynamic(%arg0: tensor, %arg1: tensor<1xi1>) -> tensor { - // CHECK-NEXT: "xla_hlo.and" - %0 = "tf.LogicalAnd"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor - return %0: tensor -} - // CHECK-LABEL: func @and_unranked func @and_unranked(%arg0: tensor<*xi1>, %arg1: tensor<*xi1>) -> tensor<*xi1> { // CHECK: tf.LogicalAnd @@ -609,20 +585,6 @@ func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> { return %0: tensor<2xi1> } -// CHECK-LABEL: func @or_broadcast -func @or_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> { - // CHECK-NEXT: xla_hlo.or - %0 = "tf.LogicalOr"(%arg0, %arg1) : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - -// CHECK-LABEL: func @or_dynamic -func @or_dynamic(%arg0: tensor, %arg1: tensor<1xi1>) -> tensor { - // CHECK-NEXT: xla_hlo.or - %0 = "tf.LogicalOr"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor - return %0: tensor -} - // CHECK-LABEL: func @bitwise_or func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { // CHECK-NEXT: xla_hlo.or @@ -630,20 +592,6 @@ func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { return %0: tensor<4xi32> } -// CHECK-LABEL: func @bitwise_or_broadcast -func @bitwise_or_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { - // CHECK-NEXT: xla_hlo.or - %0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> - return %0: tensor<1x4xi8> -} - -// CHECK-LABEL: func @bitwise_or_dynamic -func @bitwise_or_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - // CHECK-NEXT: xla_hlo.or - %0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor - return %0: tensor -} - // CHECK-LABEL: func @bitwise_and func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { // CHECK-NEXT: xla_hlo.and @@ -651,20 +599,6 @@ func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { return %0: tensor<4xi32> } -// CHECK-LABEL: func @bitwise_and_broadcast -func @bitwise_and_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { - // CHECK-NEXT: xla_hlo.and - %0 = "tf.BitwiseAnd"(%arg0, %arg1) : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> - return %0: tensor<1x4xi8> -} - -// CHECK-LABEL: func @bitwise_and_dynamic -func @bitwise_and_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - // CHECK-NEXT: xla_hlo.and - %0 = "tf.BitwiseAnd"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor - return %0: tensor -} - // CHECK-LABEL: func @pow func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK-NEXT: xla_hlo.power @@ -672,13 +606,6 @@ func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> { return %0: tensor<2xf32> } -// CHECK-LABEL: func @pow_dynamic -func @pow_dynamic(%arg0: tensor) -> tensor { - // CHECK-NEXT: xla_hlo.power - %0 = "tf.Pow"(%arg0, %arg0) : (tensor, tensor) -> tensor - return %0: tensor -} - // CHECK-LABEL: func @diag_part // CHECK-SAME: %[[ARG:.*]]: tensor<4x3x4x3xf32> func @diag_part(%arg0: 
tensor<4x3x4x3xf32>) -> tensor<4x3xf32> { @@ -862,6 +789,8 @@ func @broadcast_to(%arg0: tensor<16xf32>) -> tensor<16x16x16x16xf32> { //===----------------------------------------------------------------------===// // Equality op legalizations. +// tf.Equal and tf.NotEqual expand from the same pattern. Full semantics are +// verified for tf.Equal and pattern application only for tf.NotEqual //===----------------------------------------------------------------------===// // CHECK-LABEL: func @equal @@ -873,14 +802,26 @@ func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { // CHECK-LABEL: func @equal_dynamic func @equal_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} + // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1] + // CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]]) + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"} %0 = "tf.Equal"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor return %0: tensor } // CHECK-LABEL: func @equal_broadcast func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} + // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.const_shape [1] + // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK-DAG: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"} %0 = "tf.Equal"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0: tensor<1x2xi1> } @@ -927,70 +868,42 @@ func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> { return %0: tensor<2xi1> } -// CHECK-LABEL: func @notequal_dynamic -func @notequal_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} - %0 = "tf.NotEqual"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor - return %0: tensor -} - -// CHECK-LABEL: func @notequal_broadcast -func @notequal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} - %0 = "tf.NotEqual"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - -// CHECK-LABEL: func @notequal_broadcast_no_incompatible_shapes_error -func @notequal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-NEXT: 
"tf.NotEqual"(%arg0, %arg1) {incompatible_shape_error = false} - %0 = "tf.NotEqual"(%arg0, %arg1) {incompatible_shape_error = false} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - -// CHECK-LABEL: func @notequal_incompatible_shape_broadcastable -func @notequal_incompatible_shape_broadcastable(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - // CHECK-NEXT: "tf.NotEqual"(%arg0, %arg1) {incompatible_shape_error = false} - %0 = "tf.NotEqual"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor, tensor<1xi32>) -> tensor - return %0: tensor -} - -// CHECK-LABEL: func @notequal_incompatible_shape_dynamic -func @notequal_incompatible_shape_dynamic(%arg0: tensor<2xi32>, %arg1: tensor) -> tensor<*xi1> { - // CHECK-NEXT: "tf.NotEqual"(%arg0, %arg1) {incompatible_shape_error = false} - %0 = "tf.NotEqual"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<2xi32>, tensor) -> tensor<*xi1> - return %0: tensor<*xi1> -} - -// CHECK-LABEL: func @notequal_incompatible_shape_both_dynamic -func @notequal_incompatible_shape_both_dynamic(%arg0: tensor, %arg1: tensor) -> tensor<*xi1> { - // CHECK-NEXT: "tf.NotEqual"(%arg0, %arg1) {incompatible_shape_error = false} - %0 = "tf.NotEqual"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor, tensor) -> tensor<*xi1> - return %0: tensor<*xi1> -} - //===----------------------------------------------------------------------===// // Compare op legalizations. +// These expand from the same pattern. Full semantics are checked for +// tf.Greater. Others just check that the pattern applied. //===----------------------------------------------------------------------===// // CHECK-LABEL: func @greater func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} + // CHECK: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} %0 = "tf.Greater"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0: tensor<2xi1> } // CHECK-LABEL: func @broadcast_greater func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GT"} + // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.const_shape [1] + // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK-DAG: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"} %0 = "tf.Greater"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0: tensor<1x2xi1> } // CHECK-LABEL: func @greater_dynamic -func @greater_dynamic(%arg0: tensor) -> tensor { - // CHECK: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} - %0 = "tf.Greater"(%arg0, %arg0) : (tensor, tensor) -> tensor +func @greater_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1 + // CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]]) + // 
CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"} + %0 = "tf.Greater"(%arg0, %arg1) : (tensor, tensor) -> tensor return %0: tensor } @@ -1008,13 +921,6 @@ func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { return %0: tensor<2xi1> } -// CHECK-LABEL: func @broadcast_greater_equal -func @broadcast_greater_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GE"} - %0 = "tf.GreaterEqual"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - // CHECK-LABEL: func @less func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> { // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} @@ -1022,13 +928,6 @@ func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> { return %0: tensor<2xi1> } -// CHECK-LABEL: func @broadcast_less -func @broadcast_less(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} - %0 = "tf.Less"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - // CHECK-LABEL: func @less_equal func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} @@ -1036,13 +935,6 @@ func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { return %0: tensor<2xi1> } -// CHECK-LABEL: func @broadcast_less_equal -func @broadcast_less_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LE"} - %0 = "tf.LessEqual"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - //===----------------------------------------------------------------------===// // Complex op legalizations. @@ -1596,6 +1488,44 @@ func @unhandled_partitioned_call_2(%arg0: tensor, %arg1: tensor<*xi32>) -> return %0, %1 : tensor, tensor } + +//===----------------------------------------------------------------------===// +// ReverseV2 op legalization. 
+//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @reverse_func_32 +func @reverse_func_32(%arg0: tensor<5xi32>) -> tensor<5xi32> { + %axis = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> (tensor<1xi32>) + + // CHECK: [[VAL:%.+]] = "xla_hlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} + %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi32>) -> tensor<5xi32> + + // CHECK: return [[VAL]] : tensor<5xi32> + return %reversed : tensor<5xi32> +} + +// CHECK-LABEL: @reverse_func_64 +func @reverse_func_64(%arg0: tensor<5xi32>) -> tensor<5xi32> { + %axis = "tf.Const"() {value = dense<0> : tensor<1xi64>} : () -> (tensor<1xi64>) + + // CHECK: [[VAL:%.+]] = "xla_hlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} + %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi64>) -> tensor<5xi32> + + // CHECK: return [[VAL]] : tensor<5xi32> + return %reversed : tensor<5xi32> +} + +// CHECK-LABEL: @reverse_func_neg +func @reverse_func_neg(%arg0: tensor<5x5xi32>) -> tensor<5x5xi32> { + %axis = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>) + + // CHECK: [[VAL:%.+]] = "xla_hlo.reverse"(%arg0) {dimensions = dense<1> : tensor<1xi64>} + %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5x5xi32>, tensor<1xi32>) -> tensor<5x5xi32> + + // CHECK: return [[VAL]] : tensor<5x5xi32> + return %reversed : tensor<5x5xi32> +} + //===----------------------------------------------------------------------===// // StatefulPartitionedCall op legalization. //===----------------------------------------------------------------------===// @@ -2205,13 +2135,6 @@ func @sin_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { return %0 : tensor<*xf32> } -// CHECK-LABEL: func @round -func @round(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: "xla_hlo.round_nearest_afz"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - %0 = "tf.Round"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - return %0 : tensor<2xf32> -} - // CHECK-LABEL: func @rsqrt func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK: "xla_hlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> @@ -3720,11 +3643,11 @@ func @unsorted_segment_max(%data: tensor<8x?x64xf32>, %segment_ids : tensor, %arg1: tensor<16x5xi32>) -> tensor<16x2x5x3xf32> { - // CHECK: "xla_hlo.torch_index_select"(%arg0, %arg1) {batch_dims = 1 : i64, dim = 2 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>) -> tensor<16x2x5x3xf32> +func @gather_v2(%arg0: tensor<16x2x3xf32>, %arg1: tensor<16x5xi32>) -> tensor<16x2x5xf32> { + // CHECK: "xla_hlo.torch_index_select"(%arg0, %arg1) {batch_dims = 1 : i64, dim = 2 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>) -> tensor<16x2x5xf32> %0 = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32> - %1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = -1 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>, tensor<1xi32>) -> tensor<16x2x5x3xf32> - return %1 : tensor<16x2x5x3xf32> + %1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = -1 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>, tensor<1xi32>) -> tensor<16x2x5xf32> + return %1 : tensor<16x2x5xf32> } // CHECK-LABEL: @gather_v2_dynamic @@ -4081,6 +4004,41 @@ func @xla_sharding(%arg0: tensor<4x16xf32>) -> tensor<4x16xf32> { return %0 : tensor<4x16xf32> } +// CHECK-LABEL: inplace_update_one +func @inplace_update_one(%arg0: tensor<8x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor<1xi32>) -> tensor<8x4xf32> { + // CHECK-DAG: [[CST:%.+]] = xla_hlo.constant dense<0> + // CHECK-DAG: 
[[SLICE1:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[SLICE2:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + // CHECK-DAG: [[RESHAPE1:%.+]] = "xla_hlo.reshape"([[SLICE1]]) + // CHECK-DAG: [[UPDATE:%.+]] = "xla_hlo.dynamic-update-slice"(%arg0, [[SLICE2]], [[RESHAPE1]], [[CST]]) + %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x4xf32>, tensor<1xi32>, tensor<1x4xf32>) -> tensor<8x4xf32> + + // CHECK: return [[UPDATE]] + return %0 : tensor<8x4xf32> +} + +// CHECK-LABEL: inplace_update_three +func @inplace_update_three(%arg0: tensor<8x8x4xf32>, %arg1: tensor<3x8x4xf32>, %arg2: tensor<3xi32>) -> tensor<8x8x4xf32> { + // CHECK-DAG: [[CST:%.+]] = xla_hlo.constant dense<0> + // CHECK-DAG: [[SLICE1:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[SLICE2:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[SLICE3:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<3> : tensor<1xi64>, start_indices = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[SLICE4:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[1, 8, 4]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} + // CHECK-DAG: [[SLICE5:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[2, 8, 4]> : tensor<3xi64>, start_indices = dense<[1, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} + // CHECK-DAG: [[SLICE6:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[3, 8, 4]> : tensor<3xi64>, start_indices = dense<[2, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} + // CHECK-DAG: [[RESHAPE1:%.+]] = "xla_hlo.reshape"([[SLICE1]]) + // CHECK-DAG: [[RESHAPE2:%.+]] = "xla_hlo.reshape"([[SLICE2]]) + // CHECK-DAG: [[RESHAPE3:%.+]] = "xla_hlo.reshape"([[SLICE3]]) + // CHECK-DAG: [[UPDATE1:%.+]] = "xla_hlo.dynamic-update-slice"(%arg0, [[SLICE4]], [[RESHAPE1]], [[CST]], [[CST]]) + // CHECK-DAG: [[UPDATE2:%.+]] = "xla_hlo.dynamic-update-slice"([[UPDATE1]], [[SLICE5]], [[RESHAPE2]], [[CST]], [[CST]]) + // CHECK-DAG: [[UPDATE3:%.+]] = "xla_hlo.dynamic-update-slice"([[UPDATE2]], [[SLICE6]], [[RESHAPE3]], [[CST]], [[CST]]) + %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x8x4xf32>, tensor<3xi32>, tensor<3x8x4xf32>) -> tensor<8x8x4xf32> + + // CHECK: return [[UPDATE3]] : tensor<8x8x4xf32> + return %0 : tensor<8x8x4xf32> +} + + // CHECK-LABEL: xla_dynamic_update_slice func @xla_dynamic_update_slice(%arg0: tensor<4x16xf32>, %arg1: tensor<2x4xf32>, %arg2: tensor<2xi32>) -> tensor<4x16xf32> { // CHECK: [[SLICE0:%.+]] = "xla_hlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> @@ -4103,6 +4061,21 @@ func @xla_dynamic_update_slice2(%arg0: tensor<4xf32>, %arg1: tensor<2xf32>, %arg return %0 : tensor<4xf32> } +//===----------------------------------------------------------------------===// +// AllToAll op legalizations. 
+//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @alltoall_basic +func @alltoall_basic(%input: tensor<10xf32>) -> tensor<10xf32> { + %group_assignment = "tf.Const" () { + value = dense<[[0, 2, 4, 6], [1, 3, 5, 7], [3, 5, 6, 8]]> : tensor<3x4xi32> + } : () -> tensor<3x4xi32> + %result = "tf.AllToAll"(%input, %group_assignment) {T = f32, concat_dimension = 1 : i64, split_count = 2 : i64, split_dimension = 0 : i64} : (tensor<10xf32>, tensor<3x4xi32>) -> tensor<10xf32> + // CHECK: xla_hlo.all_to_all + // CHECK-SAME: replica_groups = dense<{{\[}}[0, 2, 4, 6], [1, 3, 5, 7], [3, 5, 6, 8]]> : tensor<3x4xi64> + return %result : tensor<10xf32> +} + //===----------------------------------------------------------------------===// // Cumsum op legalizations. //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir index 013748fea28..99b1766e73c 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir @@ -24,9 +24,9 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>, // CHECK-LABEL: func @fusion // CHECK: %[[C1:.*]] = constant 1 // CHECK-NOT: linalg.generic -// CHECK: loop.for {{.*}} step %[[C1]] -// CHECK: loop.for {{.*}} step %[[C1]] -// CHECK-NOT: loop.for +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK-NOT: scf.for // CHECK: linalg.generic // CHECK: addf // CHECK: linalg.generic @@ -36,9 +36,9 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>, // TILED-DAG: %[[C2:.*]] = constant 2 // TILED-DAG: %[[C3:.*]] = constant 3 // TILED-NOT: linalg.generic -// TILED: loop.for {{.*}} step %[[C2]] -// TILED: loop.for {{.*}} step %[[C3]] -// TILED-NOT: loop.for +// TILED: scf.for {{.*}} step %[[C2]] +// TILED: scf.for {{.*}} step %[[C3]] +// TILED-NOT: scf.for // TILED: linalg.generic // TILED: addf // TILED: linalg.generic @@ -46,8 +46,8 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>, // PLOOP-LABEL: func @fusion // PLOOP-NOT: linalg.generic -// PLOOP: loop.parallel -// PLOOP-NOT: loop.parallel +// PLOOP: scf.parallel +// PLOOP-NOT: scf.parallel // PLOOP: linalg.generic // PLOOP: addf // PLOOP: linalg.generic @@ -94,9 +94,9 @@ func @fusion_of_three(%arg0: memref<100x10xf32>, // CHECK-LABEL: func @fusion // CHECK: %[[C1:.*]] = constant 1 // CHECK-NOT: linalg.generic -// CHECK: loop.for {{.*}} step %[[C1]] -// CHECK: loop.for {{.*}} step %[[C1]] -// CHECK-NOT: loop.for +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK-NOT: scf.for // CHECK: linalg.generic // CHECK: linalg.generic // CHECK: subf @@ -107,9 +107,9 @@ func @fusion_of_three(%arg0: memref<100x10xf32>, // TILED-DAG: %[[C2:.*]] = constant 2 // TILED-DAG: %[[C3:.*]] = constant 3 // TILED-NOT: linalg.generic -// TILED: loop.for {{.*}} step %[[C2]] -// TILED: loop.for {{.*}} step %[[C3]] -// TILED-NOT: loop.for +// TILED: scf.for {{.*}} step %[[C2]] +// TILED: scf.for {{.*}} step %[[C3]] +// TILED-NOT: scf.for // TILED: linalg.generic // TILED: linalg.generic // TILED: subf @@ -118,8 +118,8 @@ func @fusion_of_three(%arg0: memref<100x10xf32>, // PLOOP-LABEL: func @fusion_of_three // PLOOP-NOT: linalg.generic -// PLOOP: loop.parallel -// PLOOP-NOT: loop.parallel +// PLOOP: scf.parallel +// PLOOP-NOT: scf.parallel // PLOOP: 
linalg.generic // PLOOP: linalg.generic // PLOOP: subf @@ -147,11 +147,11 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32 // CHECK-LABEL: func @fusion_4d // CHECK: %[[C1:.*]] = constant 1 // CHECK-NOT: linalg.generic -// CHECK: loop.for {{.*}} step %[[C1]] -// CHECK: loop.for {{.*}} step %[[C1]] -// CHECK: loop.for {{.*}} step %[[C1]] -// CHECK: loop.for {{.*}} step %[[C1]] -// CHECK-NOT: loop.for +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK-NOT: scf.for // CHECK: linalg.generic // CHECK: addf // CHECK: linalg.generic @@ -161,9 +161,9 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32 // TILED-DAG: %[[C2:.*]] = constant 2 // TILED-DAG: %[[C3:.*]] = constant 3 // TILED-NOT: linalg.generic -// TILED: loop.for {{.*}} step %[[C2]] -// TILED: loop.for {{.*}} step %[[C3]] -// TILED-NOT: loop.for +// TILED: scf.for {{.*}} step %[[C2]] +// TILED: scf.for {{.*}} step %[[C3]] +// TILED-NOT: scf.for // TILED: linalg.generic // TILED: addf // TILED: linalg.generic @@ -171,8 +171,8 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32 // PLOOP-LABEL: func @fusion_4d // PLOOP-NOT: linalg.generic -// PLOOP: loop.parallel -// PLOOP-NOT: loop.parallel +// PLOOP: scf.parallel +// PLOOP-NOT: scf.parallel // PLOOP: linalg.generic // PLOOP: addf // PLOOP: linalg.generic diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-select-and-scatter.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-select-and-scatter.mlir index 5b763cde2ed..c640b395f4d 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-select-and-scatter.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-select-and-scatter.mlir @@ -50,19 +50,19 @@ func @select_and_scatter(%arg: memref<112x112xf32>, // Parallel loop to initialize the output buffer. // CHECK: [[INIT:%.*]] = load [[INIT_BUF]][] : memref -// CHECK: loop.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) +// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) // CHECK-SAME: to ([[C112]], [[C112]]) step ([[C1]], [[C1]]) { // CHECK: store [[INIT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]] -// CHECK: loop.yield +// CHECK: scf.yield // CHECK: } // Parallel loop over source buffer to compute scattered values. -// CHECK: loop.parallel ([[II:%.*]], [[JJ:%.*]]) = ([[C0]], [[C0]]) +// CHECK: scf.parallel ([[II:%.*]], [[JJ:%.*]]) = ([[C0]], [[C0]]) // CHECK-SAME: to ([[C56]], [[C56]]) step ([[C1]], [[C1]]) { // Window loop w.r.t. first dim. // CHECK: [[SEL_RES_I:%.*]]:4 -// CHECK-SAME: = loop.for [[WIN_I:%.*]] = [[C0]] to [[C3]] step [[C1]] +// CHECK-SAME: = scf.for [[WIN_I:%.*]] = [[C0]] to [[C3]] step [[C1]] // CHECK-SAME: iter_args( // CHECK-SAME: [[SEL_I_0:%.*]] = [[C0]], [[SEL_J_0:%.*]] = [[C0]], // CHECK-SAME: [[SEL_VAL_0:%.*]] = [[C0_F32]], @@ -71,7 +71,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>, // Window loop w.r.t. second dim. // CHECK: [[SEL_RES_J:%.*]]:4 -// CHECK-SAME: = loop.for [[WIN_J:%.*]] = [[C0]] to [[C3]] step [[C1]] +// CHECK-SAME: = scf.for [[WIN_J:%.*]] = [[C0]] to [[C3]] step [[C1]] // CHECK-SAME: iter_args( // CHECK-SAME: [[SEL_I:%.*]] = [[SEL_I_0]], [[SEL_J:%.*]] = [[SEL_J_0]], // CHECK-SAME: [[SEL_VAL:%.*]] = [[SEL_VAL_0]], @@ -102,14 +102,14 @@ func @select_and_scatter(%arg: memref<112x112xf32>, // be applied, current selected ivs (SEL_I, SEL_J) and value (SEL_VAL) are // returned in that case. 
// CHECK: [[IF_INBOUNDS_RES:%.*]]:4 -// CHECK-SAME: = loop.if [[INBOUNDS_1]] -> (index, index, f32, i1) { +// CHECK-SAME: = scf.if [[INBOUNDS_1]] -> (index, index, f32, i1) { // INBOUNDS-THEN-BODY, i.e. if INBOUNDS == true // CHECK: [[ARG_ELEM:%.*]] = load [[ARG_BUF]]{{\[}}[[ARG_I]], [[ARG_J]]] // CHECK: [[IF_INIT_RES:%.*]]:4 - // CHECK-SAME: = loop.if [[SEL_INIT]] -> (index, index, f32, i1) { + // CHECK-SAME: = scf.if [[SEL_INIT]] -> (index, index, f32, i1) { // INIT-THEN-BODY, i.e. INBOUNDS == true and INIT = true @@ -133,40 +133,40 @@ func @select_and_scatter(%arg: memref<112x112xf32>, // Depending on PRED, return ARG ivs & elem or current select ivs and value. - // CHECK: [[IF_PRED_RES:%.*]]:4 = loop.if [[PRED]] - // CHECK: loop.yield [[ARG_I]], [[ARG_J]], [[ARG_ELEM]], [[CTRUE]] + // CHECK: [[IF_PRED_RES:%.*]]:4 = scf.if [[PRED]] + // CHECK: scf.yield [[ARG_I]], [[ARG_J]], [[ARG_ELEM]], [[CTRUE]] // CHECK: } else { - // CHECK: loop.yield [[SEL_I]], [[SEL_J]], [[SEL_VAL]], [[SEL_INIT]] + // CHECK: scf.yield [[SEL_I]], [[SEL_J]], [[SEL_VAL]], [[SEL_INIT]] // CHECK: } // INIT-THEN-BODY yield. - // CHECK: loop.yield [[IF_PRED_RES]]#0, [[IF_PRED_RES]]#1, + // CHECK: scf.yield [[IF_PRED_RES]]#0, [[IF_PRED_RES]]#1, // CHECK-SAME: [[IF_PRED_RES]]#2, [[IF_PRED_RES]]#3 // INIT-ELSE-BODY, i.e. if INBOUNDS == TRUE and INIT == FALSE, returns ARG // ivs and element without computing Select function. - // CHECK: loop.yield [[ARG_I]], [[ARG_J]], [[ARG_ELEM]], + // CHECK: scf.yield [[ARG_I]], [[ARG_J]], [[ARG_ELEM]], // CHECK-SAME: [[CTRUE]] : index, index, f32, i1 // CHECK: } // INBOUNDS-THEN-BODY yield. - // CHECK: loop.yield [[IF_INIT_RES]]#0, [[IF_INIT_RES]]#1, [[IF_INIT_RES]]#2, + // CHECK: scf.yield [[IF_INIT_RES]]#0, [[IF_INIT_RES]]#1, [[IF_INIT_RES]]#2, // CHECK-SAME: [[IF_INIT_RES]]#3 : index, index, f32, i1 // CHECK: } // INBOUNDS-ELSE-REGION, i.e. if INBOUNDS == FALSE // We are in the pad area, return current iter_args. - // CHECK: loop.yield [[SEL_I]], [[SEL_J]], [[SEL_VAL]], + // CHECK: scf.yield [[SEL_I]], [[SEL_J]], [[SEL_VAL]], // CHECK-SAME: [[SEL_INIT]] : index, index, f32, i1 // CHECK: } // Window loop w.r.t. second dim yield. -// CHECK: loop.yield [[IF_INBOUNDS_RES]]#0, [[IF_INBOUNDS_RES]]#1, +// CHECK: scf.yield [[IF_INBOUNDS_RES]]#0, [[IF_INBOUNDS_RES]]#1, // CHECK-SAME: [[IF_INBOUNDS_RES]]#2, [[IF_INBOUNDS_RES]]#3 // CHECK: } // Window loop w.r.t. first dim yield. 
-// CHECK: loop.yield [[SEL_RES_J]]#0, [[SEL_RES_J]]#1, [[SEL_RES_J]]#2, +// CHECK: scf.yield [[SEL_RES_J]]#0, [[SEL_RES_J]]#1, [[SEL_RES_J]]#2, // CHECK-SAME: [[SEL_RES_J]]#3 : index, index, f32, i1 // CHECK: } @@ -196,4 +196,4 @@ func @select_and_scatter(%arg: memref<112x112xf32>, // CHECK: atomic_yield [[RES]] : f32 // Parallel loop over source buffer yield -// CHECK: loop.yield +// CHECK: scf.yield diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-gpu.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-gpu.mlir index 4d878cee6f4..16ffbf241b0 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-gpu.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-gpu.mlir @@ -22,7 +22,7 @@ func @reduce(%arg: memref<100x10xf32>, // CHECK-DAG: %[[LB:.*]] = constant 0 : index // CHECK-DAG: %[[UB:.*]] = constant 10 : index // CHECK-DAG: %[[STEP:.*]] = constant 1 : index -// CHECK: loop.for %[[IDX1:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { +// CHECK: scf.for %[[IDX1:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { // CHECK: %[[LHS:.*]] = linalg.slice %[[ARG2]][%[[IDX]]] : memref<100xf32>, index, memref // CHECK: %[[RHS:.*]] = linalg.slice %[[ARG0]][%[[IDX]], %[[IDX1]]] : memref<100x10xf32>, index, index, memref // CHECK: "xla_lhlo.add"(%[[LHS]], %[[RHS]], %[[LHS]]) : (memref, memref, memref) -> () diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir index b80d5ba6755..bb8010b520c 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir @@ -3,7 +3,7 @@ // CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @element_wise func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, - %result: memref<2x2xf32>) { + %result: memref<2x2xf32>) { "xla_lhlo.add"(%lhs, %rhs, %result) : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () return @@ -16,8 +16,9 @@ func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, // ----- // CHECK-LABEL: func @element_wise_with_dynamic_shape -func @element_wise_with_dynamic_shape(%lhs: memref, %rhs: memref, - %result: memref) { +func @element_wise_with_dynamic_shape(%lhs: memref, + %rhs: memref, + %result: memref) { "xla_lhlo.add"(%lhs, %rhs, %result) : (memref, memref, memref) -> () return @@ -31,22 +32,22 @@ func @element_wise_with_dynamic_shape(%lhs: memref, %rhs: memref, %rhs: memref, - %result: memref) { + %result: memref) { + "xla_lhlo.add"(%lhs, %rhs, %result) + : (memref, memref, memref) -> () + return +} // CHECK: %[[LHS:.*]] = load // CHECK: %[[RHS:.*]] = load // CHECK: %[[RES:.*]] = addf %[[LHS]], %[[RHS]] // CHECK: store %[[RES]] // CHECK-NEXT: return - "xla_lhlo.add"(%lhs, %rhs, %result) - : (memref, memref, memref) -> () - return -} // ----- // CHECK-LABEL: func @minf func @minf(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, - %result: memref<2x2xf32>) { + %result: memref<2x2xf32>) { "xla_lhlo.minimum"(%lhs, %rhs, %result) : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () return @@ -61,7 +62,7 @@ func @minf(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, // CHECK-LABEL: func @maxi func @maxi(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, - %result: memref<2x2xi32>) { + %result: memref<2x2xi32>) { "xla_lhlo.maximum"(%lhs, %rhs, %result) : (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi32>) -> () return @@ -89,8 +90,7 @@ func @and(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, // ----- // CHECK-LABEL: func @exp 
-func @exp(%input: memref<2x2xf32>, - %result: memref<2x2xf32>) { +func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { "xla_lhlo.exponential"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return @@ -103,10 +103,8 @@ func @exp(%input: memref<2x2xf32>, // ----- // CHECK-LABEL: func @log -func @log(%input: memref<2x2xf32>, - %result: memref<2x2xf32>) { - "xla_lhlo.log"(%input, %result) - : (memref<2x2xf32>, memref<2x2xf32>) -> () +func @log(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.log"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -117,10 +115,8 @@ func @log(%input: memref<2x2xf32>, // ----- // CHECK-LABEL: func @copy -func @copy(%input: memref<2x4x8xf32>, - %result: memref<2x4x8xf32>) { - "xla_lhlo.copy"(%input, %result) - : (memref<2x4x8xf32>, memref<2x4x8xf32>) -> () +func @copy(%in: memref<2x4x8xf32>, %out: memref<2x4x8xf32>) { + "xla_lhlo.copy"(%in, %out) : (memref<2x4x8xf32>, memref<2x4x8xf32>) -> () return } // CHECK: linalg.generic @@ -131,7 +127,7 @@ func @copy(%input: memref<2x4x8xf32>, // CHECK-LABEL: func @float_cmp func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, - %result: memref<2x2xi1>) { + %result: memref<2x2xi1>) { "xla_lhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "EQ"} : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xi1>) -> () return @@ -146,7 +142,8 @@ func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, // CHECK-LABEL: func @int_cmp func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, %result: memref<2x2xi1>) { - "xla_lhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "LT"} : (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi1>) -> () + "xla_lhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "LT"} + : (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi1>) -> () return } // CHECK: linalg.generic @@ -157,10 +154,10 @@ func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, // ----- // CHECK-LABEL: func @select -func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, - %result: memref<2x2xf32>) { +func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>, + %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { "xla_lhlo.select"(%pred, %lhs, %rhs, %result) - : (memref<2x2xi1>, memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () + : (memref<2x2xi1>, memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -184,20 +181,13 @@ func @iota(%out: memref<7x10xf32>) { // ----- -// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-LABEL: func @iota -func @iota(%out: memref<7x10xi64>) { - "xla_lhlo.iota"(%out) {iota_dimension = 1 : i64} : (memref<7x10xi64>) -> () - return -} - -// ----- - // CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1, d2) -> ()> // CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-LABEL: func @broadcast_scalar func @broadcast_scalar(%operand: memref, %result: memref<4x2x1xf32>) { - "xla_lhlo.broadcast"(%operand, %result) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (memref, memref<4x2x1xf32>) -> () + "xla_lhlo.broadcast"(%operand, %result) { + broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64> + } : (memref, memref<4x2x1xf32>) -> () return } // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] @@ -209,8 +199,11 @@ func @broadcast_scalar(%operand: memref, %result: memref<4x2x1xf32>) { // CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, 
d5) -> (d3, d4, d5)> // CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> // CHECK-LABEL: func @broadcast -func @broadcast(%operand: memref<4x?x16xf32>, %result: memref<4x2x1x4x?x16xf32>) { - "xla_lhlo.broadcast"(%operand, %result) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (memref<4x?x16xf32>, memref<4x2x1x4x?x16xf32>) -> () +func @broadcast(%operand: memref<4x?x16xf32>, + %result: memref<4x2x1x4x?x16xf32>) { + "xla_lhlo.broadcast"(%operand, %result) { + broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64> + } : (memref<4x?x16xf32>, memref<4x2x1x4x?x16xf32>) -> () return } // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] @@ -221,12 +214,12 @@ func @broadcast(%operand: memref<4x?x16xf32>, %result: memref<4x2x1x4x?x16xf32>) // CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)> // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> -// CHECK-LABEL: func @dynamic_broadcast -func @dynamic_broadcast(%operand: memref, - %result: memref) { - "xla_lhlo.broadcast_in_dim"(%operand, %result) - {broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64>} - : (memref, memref) -> () +// CHECK-LABEL: func @dynamic_broadcast_in_dim +func @dynamic_broadcast_in_dim(%operand: memref, + %result: memref) { + "xla_lhlo.broadcast_in_dim"(%operand, %result) { + broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64> + } : (memref, memref) -> () return } // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] @@ -237,11 +230,12 @@ func @dynamic_broadcast(%operand: memref, // CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, 0)> // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> -// CHECK-LABEL: func @broadcast -func @broadcast(%operand: memref<5x7x1xf32>, %result: memref<7x10x6x4x5xf32>) { - "xla_lhlo.broadcast_in_dim"(%operand, %result) - {broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64>} - : (memref<5x7x1xf32>, memref<7x10x6x4x5xf32>) -> () +// CHECK-LABEL: func @broadcast_in_dim_with_expansion +func @broadcast_in_dim_with_expansion(%operand: memref<5x7x1xf32>, + %result: memref<7x10x6x4x5xf32>) { + "xla_lhlo.broadcast_in_dim"(%operand, %result) { + broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64> + } : (memref<5x7x1xf32>, memref<7x10x6x4x5xf32>) -> () return } // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] @@ -252,11 +246,12 @@ func @broadcast(%operand: memref<5x7x1xf32>, %result: memref<7x10x6x4x5xf32>) { // CHECK-DAG: #[[RESULT_MAP_0:.*]] = affine_map<(d0, d1, d2) -> ()> // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK-LABEL: func @broadcast_scalar -func @broadcast_scalar(%operand: memref, %result: memref<7x10x6xf32>) { - "xla_lhlo.broadcast_in_dim"(%operand, %result) - {broadcast_dimensions = dense<[]> : tensor<0xi64>} - : (memref, memref<7x10x6xf32>) -> () +// CHECK-LABEL: func @broadcast_in_dim_scalar +func @broadcast_in_dim_scalar(%operand: memref, + %result: memref<7x10x6xf32>) { + "xla_lhlo.broadcast_in_dim"(%operand, %result) { + broadcast_dimensions = dense<[]> : tensor<0xi64> + } : (memref, memref<7x10x6xf32>) -> () return } // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[RESULT_MAP_0]], #[[RESULT_MAP]]] @@ -265,9 +260,26 @@ func @broadcast_scalar(%operand: memref, %result: memref<7x10x6xf32>) { // ----- +// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1) -> (d0)> 
+// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @broadcast_in_dim_with_one_to_one +func @broadcast_in_dim_with_one_to_one(%operand: memref<1xf32>, %result: memref<1x5xf32>) { + "xla_lhlo.broadcast_in_dim"(%operand, %result) { + broadcast_dimensions = dense<[0]> : tensor<1xi64> + } : (memref<1xf32>, memref<1x5xf32>) -> () + return +} +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] +// CHECK-NEXT: ^bb0(%[[OPERAND:.+]]: f32, %{{.+}}: f32): +// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 + +// ----- + // CHECK-LABEL: func @constant func @constant(%value: memref) { - "xla_lhlo.constant"(%value) {value = dense<10> : tensor} : (memref) -> () + "xla_lhlo.constant"(%value) { + value = dense<10> : tensor + } : (memref) -> () return } // CHECK: %[[CONSTANT:.*]] = constant 10 : i32 @@ -275,11 +287,9 @@ func @constant(%value: memref) { // ----- -// CHECK-LABEL: func @abs -func @abs(%input: memref<2x2xf32>, - %result: memref<2x2xf32>) { - "xla_lhlo.abs"(%input, %result) - : (memref<2x2xf32>, memref<2x2xf32>) -> () +// CHECK-LABEL: func @absf +func @absf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.abs"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -289,10 +299,10 @@ func @abs(%input: memref<2x2xf32>, // ----- -func @abs(%input: memref<2x2xi32>, +// CHECK-LABEL: func @absi +func @absi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { - "xla_lhlo.abs"(%input, %result) - : (memref<2x2xi32>, memref<2x2xi32>) -> () + "xla_lhlo.abs"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () return } @@ -307,10 +317,8 @@ func @abs(%input: memref<2x2xi32>, // ----- // CHECK-LABEL: func @ceil -func @ceil(%input: memref<2x2xf32>, - %result: memref<2x2xf32>) { - "xla_lhlo.ceil"(%input, %result) - : (memref<2x2xf32>, memref<2x2xf32>) -> () +func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -321,10 +329,8 @@ func @ceil(%input: memref<2x2xf32>, // ----- // CHECK-LABEL: func @convert_i32_to_f32 -func @convert_i32_to_f32(%input: memref<2x2xi32>, - %result: memref<2x2xf32>) { - "xla_lhlo.convert"(%input, %result) - : (memref<2x2xi32>, memref<2x2xf32>) -> () +func @convert_i32_to_f32(%input: memref<2x2xi32>, %result: memref<2x2xf32>) { + "xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -337,8 +343,7 @@ func @convert_i32_to_f32(%input: memref<2x2xi32>, // CHECK-LABEL: func @convert_i16_to_i32 func @convert_i16_to_i32(%input: memref<2x2xi16>, %result: memref<2x2xi32>) { - "xla_lhlo.convert"(%input, %result) - : (memref<2x2xi16>, memref<2x2xi32>) -> () + "xla_lhlo.convert"(%input, %result) : (memref<2x2xi16>, memref<2x2xi32>) -> () return } // CHECK: linalg.generic @@ -349,10 +354,8 @@ func @convert_i16_to_i32(%input: memref<2x2xi16>, // ----- // CHECK-LABEL: func @convert_i32_to_i16 -func @convert_i32_to_i16(%input: memref<2x2xi32>, - %result: memref<2x2xi16>) { - "xla_lhlo.convert"(%input, %result) - : (memref<2x2xi32>, memref<2x2xi16>) -> () +func @convert_i32_to_i16(%input: memref<2x2xi32>, %result: memref<2x2xi16>) { + "xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi16>) -> () return } // CHECK: linalg.generic @@ -363,10 +366,8 @@ func @convert_i32_to_i16(%input: memref<2x2xi32>, // ----- // CHECK-LABEL: func @convert_f32_to_f64 -func 
@convert_f32_to_f64(%input: memref<2x2xf32>, - %result: memref<2x2xf64>) { - "xla_lhlo.convert"(%input, %result) - : (memref<2x2xf32>, memref<2x2xf64>) -> () +func @convert_f32_to_f64(%input: memref<2x2xf32>, %result: memref<2x2xf64>) { + "xla_lhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf64>) -> () return } // CHECK: linalg.generic @@ -377,10 +378,8 @@ func @convert_f32_to_f64(%input: memref<2x2xf32>, // ----- // CHECK-LABEL: func @convert_f64_to_f32 -func @convert_f64_to_f32(%input: memref<2x2xf64>, - %result: memref<2x2xf32>) { - "xla_lhlo.convert"(%input, %result) - : (memref<2x2xf64>, memref<2x2xf32>) -> () +func @convert_f64_to_f32(%input: memref<2x2xf64>, %result: memref<2x2xf32>) { + "xla_lhlo.convert"(%input, %result) : (memref<2x2xf64>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -391,10 +390,8 @@ func @convert_f64_to_f32(%input: memref<2x2xf64>, // ----- // CHECK-LABEL: func @convert_i32_to_i32 -func @convert_i32_to_i32(%input: memref<2x2xi32>, - %result: memref<2x2xi32>) { - "xla_lhlo.convert"(%input, %result) - : (memref<2x2xi32>, memref<2x2xi32>) -> () +func @convert_i32_to_i32(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { + "xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () return } // CHECK: linalg.generic @@ -404,10 +401,8 @@ func @convert_i32_to_i32(%input: memref<2x2xi32>, // ----- // CHECK-LABEL: func @convert_f32_to_f32 -func @convert_f32_to_f32(%input: memref<2x2xf32>, - %result: memref<2x2xf32>) { - "xla_lhlo.convert"(%input, %result) - : (memref<2x2xf32>, memref<2x2xf32>) -> () +func @convert_f32_to_f32(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -416,11 +411,22 @@ func @convert_f32_to_f32(%input: memref<2x2xf32>, // ----- +// CHECK-LABEL: func @convert_f32_to_i32 +func @convert_f32_to_i32(%input: memref<2x2xf32>, %result: memref<2x2xi32>) { + "xla_lhlo.convert"(%input, %result) + : (memref<2x2xf32>, memref<2x2xi32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]: i32): +// CHECK-NEXT: %[[RESULT:.*]] = fptosi %[[OPERAND_IN]] : f32 to i32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 + +// ----- + // CHECK-LABEL: func @cos -func @cos(%input: memref<2x2xf32>, - %result: memref<2x2xf32>) { - "xla_lhlo.cosine"(%input, %result) - : (memref<2x2xf32>, memref<2x2xf32>) -> () +func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -430,28 +436,37 @@ func @cos(%input: memref<2x2xf32>, // ----- -// CHECK-LABEL: func @neg -func @neg(%input: memref<2x2xf32>, +// CHECK-LABEL: func @sin +func @sin(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.negate"(%input, %result) + "xla_lhlo.sine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic // CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[RESULT:.*]] = sin %[[OPERAND_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + +// CHECK-LABEL: func @negf +func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): // CHECK-NEXT: %[[RESULT:.*]] = negf %[[OPERAND_IN]] : 
f32 // CHECK-NEXT: linalg.yield %[[RESULT]] : f32 // ----- -// CHECK-LABEL: func @neg -func @neg(%input: memref<2x2xi32>, - %result: memref<2x2xi32>) { - "xla_lhlo.negate"(%input, %result) - : (memref<2x2xi32>, memref<2x2xi32>) -> () +// CHECK-LABEL: func @negi +func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { + "xla_lhlo.negate"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () return } - // CHECK: linalg.generic // CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %[[RESULT_OUT:.*]]): // CHECK-NEXT: %[[L0:.*]] = constant 0 : i32 @@ -462,7 +477,7 @@ func @neg(%input: memref<2x2xi32>, // CHECK-LABEL: func @rem func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, - %result: memref<2x2xf32>) { + %result: memref<2x2xf32>) { "xla_lhlo.remainder"(%lhs, %rhs, %result) : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () return @@ -475,10 +490,8 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, // ----- // CHECK-LABEL: func @rsqrt -func @rsqrt(%input: memref<2x2xf32>, - %result: memref<2x2xf32>) { - "xla_lhlo.rsqrt"(%input, %result) - : (memref<2x2xf32>, memref<2x2xf32>) -> () +func @rsqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.rsqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -489,10 +502,8 @@ func @rsqrt(%input: memref<2x2xf32>, // ----- // CHECK-LABEL: func @sign -func @sign(%input: memref<2x2xf32>, - %result: memref<2x2xf32>) { - "xla_lhlo.sign"(%input, %result) - : (memref<2x2xf32>, memref<2x2xf32>) -> () +func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.sign"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -504,10 +515,8 @@ func @sign(%input: memref<2x2xf32>, // ----- // CHECK-LABEL: func @sqrt -func @sqrt(%input: memref<2x2xf32>, - %result: memref<2x2xf32>) { - "xla_lhlo.sqrt"(%input, %result) - : (memref<2x2xf32>, memref<2x2xf32>) -> () +func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -518,10 +527,8 @@ func @sqrt(%input: memref<2x2xf32>, // ----- // CHECK-LABEL: func @tanh -func @tanh(%input: memref<2x2xf32>, - %result: memref<2x2xf32>) { - "xla_lhlo.tanh"(%input, %result) - : (memref<2x2xf32>, memref<2x2xf32>) -> () +func @tanh(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.tanh"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -529,6 +536,48 @@ func @tanh(%input: memref<2x2xf32>, // CHECK-NEXT: %[[RESULT:.*]] = tanh %[[OPERAND_IN]] : f32 // CHECK-NEXT: linalg.yield %[[RESULT]] : f32 +// ----- + +// CHECK-LABEL: func @complex +func @complex(%real: memref<2x2xf32>, + %imag: memref<2x2xf32>, + %cplx: memref<2x2xcomplex<f32>>) { + "xla_lhlo.complex"(%real, %imag, %cplx) + : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xcomplex<f32>>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[RE:.*]]: f32, %[[IM:.*]]: f32, %[[CP:.*]]: complex<f32>): +// CHECK-NEXT: %[[RESULT:.*]] = create_complex %[[RE]], %[[IM]] : complex<f32> +// CHECK-NEXT: linalg.yield %[[RESULT]] : complex<f32> + +// ----- + +// CHECK-LABEL: func @real +func @real(%cplx: memref<2x2xcomplex<f32>>, + %real: memref<2x2xf32>) { + "xla_lhlo.real"(%cplx, %real) + : (memref<2x2xcomplex<f32>>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[CPLX_IN:.*]]: complex<f32>, %[[REAL_OUT:.*]]: f32): +// CHECK-NEXT: %[[REAL:.*]] = re 
%[[CPLX_IN:.*]] : complex<f32> +// CHECK-NEXT: linalg.yield %[[REAL]] : f32 + +// ----- + +// CHECK-LABEL: func @imag +func @imag(%cplx: memref<2x2xcomplex<f32>>, + %imag: memref<2x2xf32>) { + "xla_lhlo.imag"(%cplx, %imag) + : (memref<2x2xcomplex<f32>>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[CPLX_IN:.*]]: complex<f32>, %[[IMAG_OUT:.*]]: f32): +// CHECK-NEXT: %[[IMAG:.*]] = im %[[CPLX_IN:.*]] : complex<f32> +// CHECK-NEXT: linalg.yield %[[IMAG]] : f32 // ----- @@ -558,7 +607,8 @@ func @slice(%operand: memref<?x?xf32>, %result: memref<?x?xf32>) { // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @reshape_3D_2D func @reshape_3D_2D(%arg0: memref<12x1x42xi32>, %arg1 : memref<12x42xi32>) { - "xla_lhlo.reshape"(%arg0, %arg1) : (memref<12x1x42xi32>, memref<12x42xi32>) -> () + "xla_lhlo.reshape"(%arg0, %arg1) + : (memref<12x1x42xi32>, memref<12x42xi32>) -> () return } // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] @@ -569,7 +619,8 @@ func @reshape_3D_2D(%arg0: memref<12x1x42xi32>, %arg1 : memref<12x42xi32>) { // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @reshape_4D_2D func @reshape_4D_2D(%arg0: memref<12x42x1x1xi32>, %arg1 : memref<12x42xi32>) { - "xla_lhlo.reshape"(%arg0, %arg1) : (memref<12x42x1x1xi32>, memref<12x42xi32>) -> () + "xla_lhlo.reshape"(%arg0, %arg1) + : (memref<12x42x1x1xi32>, memref<12x42xi32>) -> () return } // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] @@ -580,7 +631,8 @@ func @reshape_4D_2D(%arg0: memref<12x42x1x1xi32>, %arg1 : memref<12x42xi32>) { // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> // CHECK-LABEL: func @reshape_2D_4D func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) { - "xla_lhlo.reshape"(%arg0, %arg1) : (memref<12x42xi32>, memref<12x1x42x1xi32>) -> () + "xla_lhlo.reshape"(%arg0, %arg1) + : (memref<12x42xi32>, memref<12x1x42x1xi32>) -> () return } // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir index cb169e060ef..32c367f97d6 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-parallel-loops.mlir @@ -22,13 +22,13 @@ func @reduce(%arg: memref<100x10x5xf32>, // CHECK-DAG: [[C10:%.*]] = constant 10 : index // CHECK-DAG: [[C100:%.*]] = constant 100 : index // CHECK: [[INIT:%.*]] = load [[INIT_BUF]] -// CHECK: loop.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]]) +// CHECK: scf.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]]) // CHECK-SAME: to ([[C100]], [[C5]]) step ([[C1]], [[C1]]) { -// CHECK: [[REDUCTION_RESULT:%.*]] = loop.parallel ([[J:%.*]]) = +// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[J:%.*]]) = // CHECK-SAME: ([[C0]]) to ([[C10]]) step ([[C1]]) init ([[INIT]]) -> f32 { // CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]] // CHECK-SAME: {{\[}}[[I]], [[J]], [[K]]] : memref<100x10x5xf32> -// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 { +// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 { // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): // CHECK: [[ELEM_BUF:%.*]] = alloc() : memref // CHECK: [[ACC_BUF:%.*]] = alloc() : memref @@ -37,12 +37,12 @@ func @reduce(%arg: memref<100x10x5xf32>, // CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref // CHECK: store [[ACC]], [[ACC_BUF]][] : memref // CHECK:
"xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref -// CHECK: loop.reduce.return [[ACC_RESULT]] : f32 +// CHECK: scf.reduce.return [[ACC_RESULT]] : f32 // CHECK: } -// CHECK: loop.yield +// CHECK: scf.yield // CHECK: } // CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]] -// CHECK: loop.yield +// CHECK: scf.yield // ----- @@ -66,10 +66,10 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>, // CHECK-DAG: [[C1:%.*]] = constant 1 : index // CHECK-DAG: [[C100:%.*]] = constant 100 : index // CHECK: [[INIT:%.*]] = load [[INIT_BUF]] -// CHECK: [[REDUCTION_RESULT:%.*]] = loop.parallel ([[I:%.*]]) = ([[C0]]) +// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[I:%.*]]) = ([[C0]]) // CHECK-SAME: to ([[C100]]) step ([[C1]]) init ([[INIT]]) -> f32 { // CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]]{{\[}}[[I]]{{\]}} -// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 { +// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 { // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): // CHECK: [[ELEM_BUF:%.*]] = alloc() : memref // CHECK: [[ACC_BUF:%.*]] = alloc() : memref @@ -78,9 +78,9 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>, // CHECK: store [[ACC]], [[ACC_BUF]][] : memref // CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref -// CHECK: loop.reduce.return [[ACC_RESULT]] +// CHECK: scf.reduce.return [[ACC_RESULT]] // CHECK: } -// CHECK: loop.yield +// CHECK: scf.yield // CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[C0]]] // ----- @@ -107,13 +107,13 @@ func @dynamic_reduce(%arg: memref, // CHECK: [[DIM1:%.*]] = dim [[ARG_BUF]], 1 : memref // CHECK: [[DIM2:%.*]] = dim [[ARG_BUF]], 2 : memref // CHECK: [[INIT:%.*]] = load [[INIT_BUF]] -// CHECK: loop.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]]) +// CHECK: scf.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]]) // CHECK-SAME: to ([[DIM0]], [[DIM2]]) step ([[C1]], [[C1]]) { -// CHECK: [[REDUCTION_RESULT:%.*]] = loop.parallel ([[J:%.*]]) = +// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[J:%.*]]) = // CHECK-SAME: ([[C0]]) to ([[DIM1]]) step ([[C1]]) init ([[INIT]]) -> f32 { // CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]] // CHECK-SAME: {{\[}}[[I]], [[J]], [[K]]] : memref -// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 { +// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 { // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): // CHECK: [[ELEM_BUF:%.*]] = alloc() : memref // CHECK: [[ACC_BUF:%.*]] = alloc() : memref @@ -122,12 +122,12 @@ func @dynamic_reduce(%arg: memref, // CHECK: store [[ACC]], [[ACC_BUF]][] : memref // CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref -// CHECK: loop.reduce.return [[ACC_RESULT]] : f32 +// CHECK: scf.reduce.return [[ACC_RESULT]] : f32 // CHECK: } -// CHECK: loop.yield +// CHECK: scf.yield // CHECK: } // CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]] -// CHECK: loop.yield +// CHECK: scf.yield // ----- @@ -158,9 +158,9 @@ func @reduce_window(%arg: memref<112x112xf32>, // CHECK-DAG: [[C56:%.*]] = constant 56 : index // CHECK-DAG: [[C112:%.*]] = constant 112 : index // CHECK: [[INIT:%.*]] = load [[INIT_BUF]][] : memref -// CHECK: loop.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) +// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) // CHECK-SAME: to ([[C56]], [[C56]]) step ([[C1]], [[C1]]) { -// CHECK: 
[[REDUCTION_RESULT:%.*]] = loop.parallel +// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel // CHECK-SAME: ([[IW:%.*]], [[JW:%.*]]) = ([[C0]], [[C0]]) // CHECK-SAME: to ([[C3]], [[C3]]) step ([[C1]], [[C1]]) // CHECK-SAME: init ([[INIT]]) -> f32 { @@ -177,15 +177,15 @@ func @reduce_window(%arg: memref<112x112xf32>, // CHECK: [[INDEX_J_FITS:%.*]] = cmpi "ult", [[INDEX_J]], [[C112]] // CHECK: [[IN_BOUNDS_1:%.*]] = and [[IN_BOUNDS_0]], [[INDEX_J_FITS]] -// CHECK: [[ELEM_TO_REDUCE:%.*]] = loop.if [[IN_BOUNDS_1]] -> (f32) { +// CHECK: [[ELEM_TO_REDUCE:%.*]] = scf.if [[IN_BOUNDS_1]] -> (f32) { // CHECK: [[OPERAND_ELEM:%.*]] = // CHECK-SAME: load [[OPERAND_BUF]]{{\[}}[[INDEX_I]], [[INDEX_J]]] -// CHECK: loop.yield [[OPERAND_ELEM]] : f32 +// CHECK: scf.yield [[OPERAND_ELEM]] : f32 // CHECK: } else { -// CHECK: loop.yield [[INIT]] : f32 +// CHECK: scf.yield [[INIT]] : f32 // CHECK: } -// CHECK: loop.reduce([[ELEM_TO_REDUCE]]) : f32 { +// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 { // CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): // CHECK: [[ELEM_BUF:%.*]] = alloc() : memref // CHECK: [[ACC_BUF:%.*]] = alloc() : memref @@ -194,12 +194,12 @@ func @reduce_window(%arg: memref<112x112xf32>, // CHECK: store [[ACC]], [[ACC_BUF]][] : memref // CHECK: "xla_lhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref -// CHECK: loop.reduce.return [[ACC_RESULT]] : f32 +// CHECK: scf.reduce.return [[ACC_RESULT]] : f32 // CHECK: } -// CHECK: loop.yield +// CHECK: scf.yield // CHECK: } // CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]] -// CHECK: loop.yield +// CHECK: scf.yield // CHECK: } // CHECK: return // CHECK: } diff --git a/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir b/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir index 4050340ce49..2340650dda8 100644 --- a/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir +++ b/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir @@ -20,6 +20,17 @@ func @addBroadcastLhs(%arg0: tensor<4xf32>, %arg1: tensor<1x4xf32>) -> tensor<1x // ----- +// CHECK-LABEL: @addBroadcastEqual +func @addBroadcastEqual(%arg0: tensor<4x1xf32>, %arg1: tensor<1x4xf32>) -> tensor<4x4xf32> { + // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x1xf32>) -> tensor<4x4xf32> + // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<4x4xf32> + // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor<4x4xf32> + %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4x1xf32>, tensor<1x4xf32>) -> tensor<4x4xf32> + return %0 : tensor<4x4xf32> +} + +// ----- + // CHECK-LABEL: @addBroadcastMultidimension func @addBroadcastMultidimension(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x4xf32> { // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>) -> tensor<1x1x4xf32> diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir index 3650307ea94..15fa91588a5 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir @@ -294,6 +294,12 @@ func @main() { // CHECK: f16[4] 
constant({1, -4, -65504, 0.015625} %cst_8 = constant dense<[1.0e+00, -4.0e+00, -65504.0e+00, 1.5625e-02]> : tensor<4xf16> + // CHECK: c64[] constant((1, 0)) + %cst_9 = constant dense<(1.000000e+00,0.000000e+00)> : tensor> + + // CHECK: c128[] constant((1, 0)) + %cst_10 = constant dense<(1.000000e+00,0.000000e+00)> : tensor> + return } @@ -1038,3 +1044,16 @@ func @main(%arg0: tensor<4xui8>) -> (tensor<4xui8>) { // CHECK: ENTRY // CHECK: %[[ARG0:.*]] = u8[4] parameter(0) // ROOT %[[RESULT:.*]] = u8[4] not(u8[4] %[[ARG0]]) + +// ----- + +// CHECK: HloModule +func @main(%arg0: tensor<4xi32>) -> (tensor<*xi32>) { + %0 = "xla_hlo.not"(%arg0) : (tensor<4xi32>) -> tensor<4xi32> + %1 = tensor_cast %0 : tensor<4xi32> to tensor<*xi32> + return %1 : tensor<*xi32> +} + +// CHECK: ENTRY +// CHECK: %[[ARG0:.*]] = s32[4] parameter(0) +// ROOT %[[RESULT:.*]] = s32[4] not(s32[4] %[[ARG0]]) diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export_errors.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export_errors.mlir new file mode 100644 index 00000000000..97c53cb5f9f --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/export_errors.mlir @@ -0,0 +1,7 @@ +// RUN: not tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s 2>&1 | FileCheck %s + +// CHECK: Opaque elements attr not supported +func @main() { + %0 = "tf.Const"() {value = opaque<"tf", "0x0123456789ABCDEF"> : tensor<4xf32>} : () -> tensor<4xf32> + return +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt index d1133057544..207a8f2eabc 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt @@ -212,10 +212,14 @@ add { // CHECK: dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xbf16> %constant.3 = bf16[4] constant({1, 2, 3, 4}) + // CHECK: dense<(1.000000e+00,0.000000e+00)> : tensor> + %constant.4 = c64[] constant((1, 0)) + + // CHECK: dense<(1.000000e+00,0.000000e+00)> : tensor> + %constant.5 = c128[] constant((1, 0)) + // CHECK: dense<[1.000000e+00, -4.000000e+00, -6.550400e+04, 1.562500e-02]> : tensor<4xf16> - ROOT %constant.4 = f16[4] constant({1, -4, -65504, 0.015625}) - - + ROOT %constant.6 = f16[4] constant({1, -4, -65504, 0.015625}) } // TODO(b/129422361) Potentially update when copy, reshape, and conv have actual @@ -244,8 +248,8 @@ add { // CHECK-SAME: kernel_input_feature_dimension = 2 : i64 // CHECK-SAME: kernel_output_feature_dimension = 3 : i64 // CHECK-SAME: kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64> - // CHECK-SAME: output_batch_dimension = 0 : i64 - // CHECK-SAME: output_feature_dimension = 3 : i64 + // CHECK-SAME: output_batch_dimension = 3 : i64 + // CHECK-SAME: output_feature_dimension = 0 : i64 // CHECK-SAME: output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64> // CHECK-SAME: } // CHECK-SAME: feature_group_count = 1 : i64 @@ -255,11 +259,11 @@ add { // CHECK-SAME: rhs_dilations = dense<[2, 3]> : tensor<2xi64> // CHECK-SAME: window_strides = dense<[4, 5]> : tensor<2xi64> // CHECK-SAME: } - // CHECK-SAME: (tensor<256x32x32x6xf32>, tensor<2x2x1x1xf32>) -> tensor<256x30x30x16xf32> + // CHECK-SAME: (tensor<256x32x32x6xf32>, tensor<2x2x1x1xf32>) -> tensor<16x30x30x256xf32> - %convolution.4 = f32[256,30,30,16]{2,1,3,0} convolution(%reshape.2, %constant.3), window={size=3x3 stride=4x5 pad=44_45x60_60 rhs_dilate=2x3}, dim_labels=b01f_01io->b01f, metadata={op_type="Conv2D" 
op_name="embedded_inference/conv_model/conv_0/Conv2D"} + %convolution.4 = f32[16,30,30,256]{2,1,3,0} convolution(%reshape.2, %constant.3), window={size=3x3 stride=4x5 pad=44_45x60_60 rhs_dilate=2x3}, dim_labels=b01f_01io->f01b, metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"} - // CHECK-NEXT: %3 = "xla_hlo.reshape"(%2) {name = "{{.*}}"} : (tensor<256x30x30x16xf32>) -> tensor<256x30x30x16xf32> + // CHECK-NEXT: %3 = "xla_hlo.reshape"(%2) {name = "{{.*}}"} : (tensor<16x30x30x256xf32>) -> tensor<256x30x30x16xf32> %reshape.5 = f32[256,30,30,16]{3,2,1,0} reshape(%convolution.4), metadata={op_name="HLO_Retvals"} // CHECK-NEXT: "xla_hlo.tuple"(%3) {name = "{{.*}}"} : (tensor<256x30x30x16xf32>) -> tuple> diff --git a/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.h b/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.h index d8b4c2554bb..ced5769b44c 100644 --- a/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.h +++ b/tensorflow/compiler/mlir/xla/transforms/buffer_assignment.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_BUFFER_ASSIGNMENT_H_ #define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_BUFFER_ASSIGNMENT_H_ -#include "mlir/Analysis/Dominance.h" #include "mlir/Analysis/Liveness.h" -#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Builders.h" // TF:llvm-project +#include "mlir/IR/Dominance.h" #include "mlir/IR/Operation.h" // TF:llvm-project #include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" // TF:llvm-project diff --git a/tensorflow/compiler/mlir/xla/transforms/chlo_legalize_to_hlo.cc b/tensorflow/compiler/mlir/xla/transforms/chlo_legalize_to_hlo.cc index a20511a95fc..0c9585a817f 100644 --- a/tensorflow/compiler/mlir/xla/transforms/chlo_legalize_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/transforms/chlo_legalize_to_hlo.cc @@ -33,24 +33,23 @@ namespace { // Converts binary ops that statically are determined to not broadcast directly // to the corresponding xla_hlo non-broadcasting op. template -struct ConvertTrivialNonBroadcastBinaryOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - ChloOpTy op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { +struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ChloOpTy op, + PatternRewriter &rewriter) const override { // Only rewrite for statically determinable non-broadcasting cases. - auto lhs = operands[0].getType().dyn_cast(); - auto rhs = operands[1].getType().dyn_cast(); - if (!lhs || !rhs) return failure(); + auto lhs_type = op.lhs().getType().template dyn_cast(); + auto rhs_type = op.rhs().getType().template dyn_cast(); + if (!lhs_type || !rhs_type) return failure(); // Requires rank broadcast. - if (lhs.getRank() != rhs.getRank()) return failure(); + if (lhs_type.getRank() != rhs_type.getRank()) return failure(); // Any dynamic dimension may require broadcasting and requires more // analysis. 
- if (!lhs.hasStaticShape() || !rhs.hasStaticShape()) return failure(); + if (!lhs_type.hasStaticShape() || !rhs_type.hasStaticShape()) + return failure(); - for (auto extents : llvm::zip(lhs.getShape(), rhs.getShape())) { + for (auto extents : llvm::zip(lhs_type.getShape(), rhs_type.getShape())) { auto lhs_extent = std::get<0>(extents); auto rhs_extent = std::get<1>(extents); if (lhs_extent != rhs_extent) { @@ -58,9 +57,8 @@ struct ConvertTrivialNonBroadcastBinaryOp } } - rewriter.replaceOp( - op, {Adaptor::CreateOp(op, op.getResult().getType(), operands[0], - operands[1], rewriter)}); + rewriter.replaceOp(op, {Adaptor::CreateOp(op, op.getResult().getType(), + op.lhs(), op.rhs(), rewriter)}); return success(); } }; @@ -83,14 +81,13 @@ struct ConvertTrivialNonBroadcastBinaryOp // Whether that is of any practical benefit remains to be seen. template struct ConvertRankedDynamicBroadcastBinaryOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - ChloOpTy op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ChloOpTy op, + PatternRewriter &rewriter) const override { // Only support ranked operands. - Value lhs = operands[0]; - Value rhs = operands[1]; + Value lhs = op.lhs(); + Value rhs = op.rhs(); auto lhs_type = lhs.getType().dyn_cast(); auto rhs_type = rhs.getType().dyn_cast(); auto result_type = diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc index aa29241048b..10f35768bbd 100644 --- a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc @@ -27,6 +27,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Transforms/BufferPlacement.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" @@ -39,16 +40,11 @@ namespace xla_hlo { namespace { constexpr StringRef kTempBufferAttr = "temp"; - -/// Returns DeallocOp to ensure that CopyOp is not inserted after dealloc. 
-Operation* FindInsertionPointForCopy(Value value) { - for (const auto& user : value.getUsers()) { - if (auto dealloc = dyn_cast(user)) { - return user; - } - } - return nullptr; -} +template +using BaseOpConversion = BufferAssignmentOpConversionPattern; +using StdReturnOpConverter = + NonVoidToVoidReturnOpConverter; Value InsertDynamicAllocAndDealloc(Location loc, Value result, Value shape_operand, @@ -92,8 +88,9 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result, return alloc; } -Value InsertAllocAndDealloc(Location loc, Value result, - ConversionPatternRewriter* rewriter) { +Value InsertAlloc(Location loc, OpResult result, + BufferAssignmentPlacer* bufferAssignment, + ConversionPatternRewriter* rewriter) { auto result_type = result.getType().dyn_cast(); if (!result_type || !result_type.hasStaticShape()) { result.getDefiningOp()->emitOpError() @@ -101,31 +98,21 @@ Value InsertAllocAndDealloc(Location loc, Value result, } auto memref_type = MemRefType::get(result_type.getShape(), result_type.getElementType()); - - Operation* op = result.getDefiningOp(); - auto block = op->getBlock(); - - OpBuilder allocBuilder(op); - allocBuilder.setInsertionPointToStart(block); // Inserting at the beginning - auto alloc = allocBuilder.create(loc, memref_type); - - alloc.setAttr(kTempBufferAttr, rewriter->getBoolAttr(true)); - - allocBuilder.setInsertionPoint(block, std::prev(block->end())); - allocBuilder.create(loc, alloc); - + OpBuilder::InsertionGuard guard(*rewriter); + rewriter->restoreInsertionPoint( + bufferAssignment->computeAllocPosition(result)); + auto alloc = rewriter->create(loc, memref_type); return alloc; } template -class HloToLhloOpConverter : public ConversionPattern { +class HloToLhloOpConverter : public BaseOpConversion { public: - explicit HloToLhloOpConverter(MLIRContext* context) - : ConversionPattern(HloOpTy::getOperationName(), 1, context) {} - + using BaseOpConversion::BaseOpConversion; LogicalResult matchAndRewrite( - Operation* op, ArrayRef operands, + HloOpTy hloOp, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { + Operation* op = hloOp.getOperation(); const auto& original_results = op->getResults(); SmallVector buffer_args(operands.begin(), operands.end()); for (auto result : llvm::enumerate(original_results)) { @@ -135,8 +122,8 @@ class HloToLhloOpConverter : public ConversionPattern { return failure(); } if (resultType.hasStaticShape()) { - buffer_args.push_back( - InsertAllocAndDealloc(op->getLoc(), result.value(), &rewriter)); + buffer_args.push_back(InsertAlloc(op->getLoc(), result.value(), + this->bufferAssignment, &rewriter)); } else { SmallVector results_shape; auto shape_type_op = dyn_cast(op); @@ -156,9 +143,9 @@ class HloToLhloOpConverter : public ConversionPattern { }; struct HloToLhloDynamicBroadcastInDimOpConverter - : public OpConversionPattern { + : public BaseOpConversion { public: - using OpConversionPattern::OpConversionPattern; + using BaseOpConversion::BaseOpConversion; LogicalResult matchAndRewrite( xla_hlo::DynamicBroadcastInDimOp op, ArrayRef operands, @@ -175,10 +162,9 @@ struct HloToLhloDynamicBroadcastInDimOpConverter } }; -struct HloToLhloReduceOpConverter - : public OpConversionPattern { +struct HloToLhloReduceOpConverter : public BaseOpConversion { public: - using OpConversionPattern::OpConversionPattern; + using BaseOpConversion::BaseOpConversion; LogicalResult matchAndRewrite( xla_hlo::ReduceOp op, ArrayRef operands, @@ -194,7 +180,8 @@ struct HloToLhloReduceOpConverter const auto& original_results = 
op.getResults(); SmallVector buffer_args(operands.begin(), operands.end()); for (auto result : original_results) { - buffer_args.push_back(InsertAllocAndDealloc(loc, result, &rewriter)); + buffer_args.push_back( + InsertAlloc(loc, result, this->bufferAssignment, &rewriter)); } auto new_op = rewriter.create( loc, llvm::None, buffer_args, op.getAttrs()); @@ -230,12 +217,12 @@ struct HloToLhloReduceOpConverter } }; -class HloToLhloTensorLoadOpConverter : public ConversionPattern { +class HloToLhloTensorLoadOpConverter + : public BaseOpConversion { public: - explicit HloToLhloTensorLoadOpConverter(MLIRContext* context) - : ConversionPattern(TensorLoadOp::getOperationName(), 1, context) {} + using BaseOpConversion::BaseOpConversion; LogicalResult matchAndRewrite( - Operation* op, ArrayRef operands, + mlir::TensorLoadOp op, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { rewriter.replaceOp(op, operands); return success(); @@ -243,13 +230,13 @@ class HloToLhloTensorLoadOpConverter : public ConversionPattern { }; // TODO(b/137624192): Rewrite into a copy and elide copy if possible. -class HloToLhloTensorStoreOpConverter : public ConversionPattern { +class HloToLhloTensorStoreOpConverter + : public BaseOpConversion { public: - explicit HloToLhloTensorStoreOpConverter(MLIRContext* context) - : ConversionPattern(TensorStoreOp::getOperationName(), 1, context) {} + using BaseOpConversion::BaseOpConversion; LogicalResult matchAndRewrite( - Operation* op, ArrayRef operands, + mlir::TensorStoreOp op, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { rewriter.replaceOpWithNewOp( op, llvm::None, operands.front(), operands.back()); @@ -291,7 +278,6 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern { // (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () // "xla_lhlo.multiply"(%0, %arg0, %arg3) : // (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () -// dealloc %0 : memref<2x2xf32> // "xla_lhlo.terminator"() : () -> () // }) : () -> () // return @@ -313,14 +299,13 @@ class HloToLhloTensorStoreOpConverter : public ConversionPattern { // %arg1: memref<4xf32>, // %arg2: memref<4xf32>) { // %0 = alloc() : memref<4xf32> -// %1 = alloc() : memref<4xf32> + // "xla_lhlo.maximum"(%arg0, %arg1, %0) : // (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> () +// %1 = alloc() : memref<4xf32> // "xla_lhlo.add"(%arg0, %0, %1) : // (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> () // "xla_lhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> () -// dealloc %0 : memref<4xf32> -// dealloc %1 : memref<4xf32> // "xla_lhlo.terminator"() : () -> () // } @@ -346,101 +331,25 @@ struct HloLegalizeToLhlo }); auto module = getOperation(); - populateHLOToLHLOConversionPattern(module.getContext(), &patterns); - - // Do partial conversion so we can have unknown ops in tests. 
- if (failed(applyPartialConversion(module, target, patterns, nullptr))) { - signalPassFailure(); - } + BufferAssignmentTypeConverter converter; + module.walk([&](FuncOp func) { + BufferAssignmentPlacer bufferAssignment(func); + OwningRewritePatternList patterns; + populateHLOToLHLOConversionPattern(func.getContext(), &bufferAssignment, + &converter, &patterns); + return WalkResult( + applyPartialConversion(func, target, patterns, &converter)); + }); } }; - -Type ConvertType(Type t) { - if (auto tensorType = t.dyn_cast()) { - return MemRefType::get(tensorType.getShape(), tensorType.getElementType()); - } - return t; -} - } // namespace -/// Transforms FuncOp arguments and results from tensors to buffers. Tensor -/// results are converted to memrefs and appended to the argument list. -class HloToLhloFuncOpConverter : public OpConversionPattern { - public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - FuncOp funcOp, ArrayRef operands, - ConversionPatternRewriter& rewriter) const final { - if (funcOp.getBody().getBlocks().size() > 1) { - funcOp.emitOpError() << "tensor to buffer conversion expects a single " - "block in the region containing the operation"; - return failure(); - } - - auto funcType = funcOp.getType(); - - TypeConverter::SignatureConversion conversion(funcType.getNumInputs()); - for (auto argType : llvm::enumerate(funcType.getInputs())) { - conversion.addInputs(argType.index(), ConvertType(argType.value())); - } - for (auto resType : funcType.getResults()) { - conversion.addInputs(ConvertType(resType)); - } - rewriter.updateRootInPlace(funcOp, [&] { - funcOp.setType( - rewriter.getFunctionType(conversion.getConvertedTypes(), llvm::None)); - rewriter.applySignatureConversion(&funcOp.getBody(), conversion); - }); - return success(); - } -}; - -/// Transforms ReturnOp to LhloTerminator. CopyOp is inserted to copy each -/// result to the corresponding buffer argument. 
-class StdToLhloReturnOpConverter : public OpConversionPattern { - public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - mlir::ReturnOp returnOp, ArrayRef operands, - ConversionPatternRewriter& rewriter) const final { - auto numReturnValues = returnOp.getNumOperands(); - auto funcOp = returnOp.getParentOfType(); - auto numFuncArgs = funcOp.getNumArguments(); - auto loc = returnOp.getLoc(); - - for (auto operand : llvm::enumerate(operands)) { - auto returnArgNumber = numFuncArgs - numReturnValues + operand.index(); - auto dstBuffer = funcOp.getArgument(returnArgNumber); - if (dstBuffer == operand.value()) { - continue; - } - - auto dealloc = FindInsertionPointForCopy(operand.value()); - - if (dealloc == nullptr) { - returnOp.emitOpError() - << "Missing dealloc for operand " << operand.index(); - return failure(); - } - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(dealloc); - rewriter.create(loc, llvm::None, operand.value(), - funcOp.getArgument(returnArgNumber)); - } - rewriter.replaceOpWithNewOp(returnOp); - return success(); - } -}; - -void populateHLOToLHLOConversionPattern(MLIRContext* context, - OwningRewritePatternList* patterns) { +void populateHLOToLHLOConversionPattern( + MLIRContext* context, BufferAssignmentPlacer* bufferAssignment, + TypeConverter* converter, OwningRewritePatternList* patterns) { // clang-format off patterns->insert< HloToLhloDynamicBroadcastInDimOpConverter, - HloToLhloFuncOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, @@ -472,8 +381,9 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context, HloToLhloReduceOpConverter, HloToLhloTensorLoadOpConverter, HloToLhloTensorStoreOpConverter, - StdToLhloReturnOpConverter - >(context); + FunctionAndBlockSignatureConverter, + StdReturnOpConverter + >(context, bufferAssignment, converter); // clang-format on } diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index a6a6829b109..10bac232b0f 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -25,6 +25,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/Dialect/Traits.h" // from @llvm-project @@ -43,9 +44,11 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" #include "tensorflow/compiler/mlir/xla/convert_op_folder.h" +#include "tensorflow/compiler/mlir/xla/ir/chlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" +#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/client/sharding_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -2589,6 +2592,21 @@ class ConvertRangeOp : public OpRewritePattern { } }; +ElementsAttr ConvertAxisAttr(Value val, ElementsAttr attr, Builder *builder) { + auto int_attr = attr.cast(); + auto type = val.getType().cast(); + + SmallVector axis; + axis.reserve(int_attr.getNumElements()); + + int64_t rank = type.getRank(); + for (auto val : int_attr.getValues()) { + axis.push_back((val.getSExtValue() + rank) % rank); + } + + return builder->getI64TensorAttr(axis); +} + /// Converts the LinSpace tensorflow op to a xla_hlo.iota op with a scaling /// and offset applied to generate the linspace values. The output tensor needs /// to have a static shape. The implementation is defined in C++ because there @@ -4181,6 +4199,68 @@ class ConvertXlaShardingOp : public OpRewritePattern { } }; +// Converts a TF InplaceUpdate op to DynamicUpdateSlice HLO. +class ConvertInplaceUpdateOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::InplaceUpdateOp op, + PatternRewriter &rewriter) const override { + auto input = op.x(); + auto indices = op.i(); + auto updates = op.v(); + + // Slice each row of `i` and `v` to perform a separate dynamic-update-slice + // on the contents of `x`. 
+ auto input_type = input.getType().cast(); + auto updates_type = updates.getType().cast(); + auto indices_type = indices.getType().cast(); + if (!indices_type.hasStaticShape()) return failure(); + + if (indices_type.getRank() != 1) return failure(); + + SmallVector unpacked_indices_type( + indices_type.getDimSize(0), + RankedTensorType::get({}, indices_type.getElementType())); + auto zero_attr = IntegerAttr::get(rewriter.getIntegerType(64), 0); + auto unpacked_indices = rewriter.create( + op.getLoc(), unpacked_indices_type, indices, zero_attr); + + SmallVector split_updates_shape; + split_updates_shape.append(updates_type.getShape().begin(), + updates_type.getShape().end()); + split_updates_shape.front() = 1; + SmallVector split_updates_type; + split_updates_type.resize( + updates_type.getShape().front(), + RankedTensorType::get(split_updates_shape, + updates_type.getElementType())); + + auto cst = + rewriter.create(op.getLoc(), zero_attr).getResult(); + auto split_updates = rewriter.create( + op.getLoc(), split_updates_type, cst, updates); + + SmallVector input_indices; + input_indices.resize(input_type.getRank(), cst); + + SmallVector starts(updates_type.getRank(), 0); + SmallVector strides(updates_type.getRank(), 1); + SmallVector limits(updates_type.getShape().begin(), + updates_type.getShape().end()); + + for (auto pair : + llvm::zip(unpacked_indices.output(), split_updates.output())) { + input_indices.front() = std::get<0>(pair); + input = rewriter.create( + op.getLoc(), op.getType(), input, std::get<1>(pair), input_indices); + } + + rewriter.replaceOp(op, input); + return success(); + } +}; + // Converts a TF XlaDynamicUpdateSlice op to DynamicUpdateSlice HLO. class ConvertXlaDynamicUpdateSliceOp : public OpRewritePattern { @@ -4785,6 +4865,62 @@ class ConvertQrOp : public OpRewritePattern { } }; +// Emits debug information which includes the number of ops of each type which +// failed to legalize. +void EmitLegalizationErrors(Operation *op, + const DenseSet &nonlegalized_ops) { + // Track the legalization failures by mapping op name to information about + // that failure: the number of unlegalized occurances of the op, and one + // example operation that failed. + std::map> op_name_to_error_info; + DenseSet error_ops; + for (Operation *nonlegalized_op : nonlegalized_ops) { + // Increment count of this legalization failure. + StringRef op_name = nonlegalized_op->getName().getStringRef(); + // If this emplace is successful, it's the first time we've encountered + // this op type. Initialize count to 0 so that after increment, it is 1. + auto insertion_result = op_name_to_error_info.emplace( + op_name, std::make_pair(0, nonlegalized_op)); + ++insertion_result.first->second.first; + } + std::vector error_messages; + error_messages.reserve(op_name_to_error_info.size()); + for (const auto &op_info : op_name_to_error_info) { + error_messages.push_back( + llvm::formatv("{0} (count: {1})", op_info.first, op_info.second.first)); + } + Location loc = op->getLoc(); + emitError(loc) << "The following operations cannot be legalized: " + << llvm::join(error_messages, "; ") + << ". These legalization failure(s) may be due to missing TF " + "to HLO lowerings and/or unsupported attributes, etc."; + // Emit more information about the missing ops. This error message + // contains useful details beyond the op name (input and output shapes, + // attributes, etc.). 
+ if (!VLOG_IS_ON(1) && nonlegalized_ops.size() != 1) { + emitError(loc) + << "Emitting more detail about one op that failed to legalize..."; + } else if (VLOG_IS_ON(1)) { + emitError(loc) << "Emitting more detail about one of each type of op " + "that failed to legalize..."; + } + for (const auto &op_info : op_name_to_error_info) { + op_info.second.second->emitOpError() << "is not legalizable"; + if (!VLOG_IS_ON(1)) break; + } +} + +// Performs the lowering to XLA dialect. +void LegalizeTF::runOnFunction() { + if (failed(legalizeTF(getFunction(), allow_partial_conversion_))) + signalPassFailure(); +} + +static PassRegistration pass( + "xla-legalize-tf", "Legalize from TensorFlow to the XLA dialect"); + +} // end namespace + #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc" LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) { @@ -4806,12 +4942,13 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) { ConvertConv3DBackpropInputOp, ConvertCumsumOp, ConvertDiagPartOp, ConvertEinsumOp, ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op, ConvertFusedBatchNormGradV3Op, - ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp, ConvertLinSpaceOp, - ConvertMaxOp, ConvertMinOp, ConvertAvgPoolOp, ConvertMaxPool2DOp, - ConvertMaxPool3DOp, ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp, - ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, - ConvertProdOp, ConvertQrOp, ConvertRangeOp, ConvertSelectV2Op, - ConvertSigmoidOp, ConvertSizeOp, ConvertSoftmaxOp, + ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp, + ConvertInplaceUpdateOp, ConvertLinSpaceOp, ConvertMaxOp, ConvertMinOp, + ConvertAvgPoolOp, ConvertMaxPool2DOp, ConvertMaxPool3DOp, + ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp, ConvertMeanOp, + ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertQrOp, + ConvertRangeOp, ConvertSelectV2Op, ConvertSigmoidOp, ConvertSizeOp, + ConvertSoftmaxOp, ConvertSoftmaxOp, ConvertSplitOp, ConvertSplitVOp, ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp, ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op, @@ -4820,7 +4957,12 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) { ConvertRandomShuffleOp, ConvertVariableShapeOp, ConvertXlaShardingOp, ConvertXlaDynamicUpdateSliceOp>(op->getContext()); + // Populate with CHLO->HLO lowerings to account for TF ops legalized to + // CHLO first. + xla_chlo::PopulateLegalizeChloToHloPatterns(context, &patterns); + ConversionTarget target(*context); + target.addIllegalDialect(); target.addLegalDialect(); target.addLegalDialect(); target.addLegalDialect(); @@ -4830,23 +4972,21 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) { if (!allow_partial_conversion) { // Fully qualify ReturnOp here as xla_hlo dialect also defines a ReturnOp. target.addLegalOp(); - return applyFullConversion(op, target, patterns); + DenseSet nonlegalized_ops; + LogicalResult result = applyPartialConversion( + op, target, patterns, /*converter=*/nullptr, &nonlegalized_ops); + // In order to enforce that the conversion result is fully converted, + // fail if there are any nonlegalized ops in the set. + if (failed(result) || !nonlegalized_ops.empty()) { + EmitLegalizationErrors(op, nonlegalized_ops); + return failure(); + } + return result; } return applyPartialConversion(op, target, patterns); } -/// Performs the lowering to XLA dialect. 
-void LegalizeTF::runOnFunction() { - if (failed(legalizeTF(getFunction(), allow_partial_conversion_))) - signalPassFailure(); -} - -static PassRegistration pass( - "xla-legalize-tf", "Legalize from TensorFlow to the XLA dialect"); - -} // end namespace - std::unique_ptr> createLegalizeTFPass( bool allow_partial_conversion) { return std::make_unique(allow_partial_conversion); diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index b2a7c1e7f62..959902692dc 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -18,6 +18,7 @@ limitations under the License. include "mlir/IR/OpBase.td" include "mlir/Dialect/StandardOps/IR/Ops.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" +include "tensorflow/compiler/mlir/xla/ir/chlo_ops.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td" def SignedIntTensor : TensorOf<[I1, I8, I16, I32, I64]>; @@ -80,6 +81,9 @@ def BiasAddFeatureDimension : NativeCodeCall< // $input needs to be a ranked tensor to identify index of the feature // dimension depending on the data_format 'NHWC' or 'NCHW'. +// TODO(laurenzo): This should be converted to do explicit broadcasting since +// it can generate broadcast dimensions that are not compatible with the simple +// xla_chlo.add broadcast_dims. def : Pat<(TF_BiasAddOp AnyRankedTensor:$input, $bias, $data_format), (HLO_AddOp $input, $bias, (BiasAddFeatureDimension $data_format, $input))>; @@ -96,16 +100,16 @@ class DirectBinaryPat : Pat<(FromOp AnyRankedTensor:$l, AnyRankedTensor:$r), (ToOp $l, $r, (BinBroadcastDimensions $l, $r))>; -foreach fromToBinPair = [[TF_AddOp, HLO_AddOp], - [TF_AddV2Op, HLO_AddOp], - [TF_DivOp, HLO_DivOp], - [TF_LeftShiftOp, HLO_ShiftLeftOp], - [TF_MaximumOp, HLO_MaxOp], - [TF_MinimumOp, HLO_MinOp], - [TF_MulOp, HLO_MulOp], - [TF_PowOp, HLO_PowOp], - [TF_RealDivOp, HLO_DivOp], - [TF_SubOp, HLO_SubOp]] in +foreach fromToBinPair = [[TF_AddOp, HLOClient_BroadcastAddOp], + [TF_AddV2Op, HLOClient_BroadcastAddOp], + [TF_DivOp, HLOClient_BroadcastDivOp], + [TF_LeftShiftOp, HLOClient_BroadcastShiftLeftOp], + [TF_MaximumOp, HLOClient_BroadcastMaxOp], + [TF_MinimumOp, HLOClient_BroadcastMinOp], + [TF_MulOp, HLOClient_BroadcastMulOp], + [TF_PowOp, HLOClient_BroadcastPowOp], + [TF_RealDivOp, HLOClient_BroadcastDivOp], + [TF_SubOp, HLOClient_BroadcastSubOp]] in def : DirectBinaryPat; def LowerRightShiftSigned : @@ -196,10 +200,10 @@ class DirectLogicalBinaryPat (ToOp $l, $r, (BinBroadcastDimensions $l, $r)), [(SignedIntTensor $l)]>; -foreach fromToBinPair = [[TF_LogicalAndOp, HLO_AndOp], - [TF_LogicalOrOp, HLO_OrOp], - [TF_BitwiseOrOp, HLO_OrOp], - [TF_BitwiseAndOp, HLO_AndOp]] in +foreach fromToBinPair = [[TF_LogicalAndOp, HLOClient_BroadcastAndOp], + [TF_LogicalOrOp, HLOClient_BroadcastOrOp], + [TF_BitwiseOrOp, HLOClient_BroadcastOrOp], + [TF_BitwiseAndOp, HLOClient_BroadcastAndOp]] in def : DirectLogicalBinaryPat; //===----------------------------------------------------------------------===// @@ -208,7 +212,8 @@ foreach fromToBinPair = [[TF_LogicalAndOp, HLO_AndOp], class DirectComparePat : Pat<(FromOp AnyRankedTensor:$l, AnyRankedTensor:$r), - (HLO_CompareOp $l, $r, (BinBroadcastDimensions $l, $r), direction)>; + (HLOClient_BroadcastCompareOp + $l, $r, (BinBroadcastDimensions $l, $r), direction)>; def : DirectComparePat; def : DirectComparePat; @@ -218,7 +223,8 @@ def : DirectComparePat; class EqualityPat : 
Pat<(FromOp AnyRankedTensor:$l, AnyRankedTensor:$r, TrueBoolAttr:$incompatible_shape_error), - (HLO_CompareOp $l, $r, (BinBroadcastDimensions $l, $r), direction), + (HLOClient_BroadcastCompareOp + $l, $r, (BinBroadcastDimensions $l, $r), direction), [(AreBroadcastCompatible $l, $r)]>; def : EqualityPat; @@ -273,6 +279,13 @@ def : Pat<(TF_CrossReplicaSumOp $input, (TF_ConstOp $group_assignment)), (HLO_CrossReplicaSumOp $input, (CastElementsToI64Elements $group_assignment))>; +//===----------------------------------------------------------------------===// +// All2All op patterns. +//===----------------------------------------------------------------------===// + +def : Pat<(TF_AllToAllOp AnyRankedTensor:$input, (TF_ConstOp $group_assignment), I64Attr:$concat_dimension, $split_dimension, $split_count), + (HLO_AllToAllOp $input, $split_dimension, $concat_dimension, $split_count, (CastElementsToI64Elements $group_assignment))>; + //===----------------------------------------------------------------------===// // FFT op patterns. //===----------------------------------------------------------------------===// @@ -513,6 +526,16 @@ foreach callOp = [TF_PartitionedCallOp, TF_StatefulPartitionedCallOp] in { [(ArgTypesMatchCallee $op, $args, $f)]>; } +//===----------------------------------------------------------------------===// +// Reverse op patterns. +//===----------------------------------------------------------------------===// + +// Handles axis conversion for TF reverse. +def ConvertAxisAttr : NativeCodeCall<"ConvertAxisAttr($0, $1, &$_builder)">; + +def : Pat<(TF_ReverseV2Op AnyRankedTensor:$values, (TF_ConstOp $axis)), + (HLO_ReverseOp $values, (ConvertAxisAttr $values, $axis))>; + //===----------------------------------------------------------------------===// // Ternary op patterns. //===----------------------------------------------------------------------===// @@ -543,7 +566,6 @@ foreach Mapping = [ [TF_LogicalNotOp, HLO_NotOp], [TF_NegOp, HLO_NegOp], [TF_RealOp, HLO_RealOp], - [TF_RoundOp, HLO_RoundOp], [TF_RsqrtOp, HLO_RsqrtOp], [TF_SinOp, HLO_SinOp], [TF_SqrtOp, HLO_SqrtOp], diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index 25bdd0f5f62..76657bd5e20 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/Optional.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project @@ -37,6 +38,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h.inc" #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" #include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h" @@ -81,28 +83,51 @@ static bool IsOpWhitelisted(Operation* op) { // clang-format off static llvm::SmallDenseSet ops = { TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), - TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), - TypeID::get(), - TypeID::get(), TypeID::get(), - TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), - TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -111,18 +136,43 @@ static bool IsOpWhitelisted(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), - TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), - TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), - TypeID::get() + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get() }; // clang-format on @@ -170,6 +220,10 @@ class FuncLegalizer { // legalization. LogicalResult LegalizeOp(Operation* op); + // Converts the given operand to expression of kind kConstant or kXlaOp. + // Emits a remark and returns expression of kind kInvalid on failure. + tensorflow::XlaExpression GetExprForOperand(Value operand, Operation* op); + FuncOp func_; std::string device_type_; @@ -296,6 +350,17 @@ LogicalResult FuncLegalizer::LegalizeOp(Operation* op) { // Transfer ownership of the kernel to a local smart pointer. auto op_kernel = absl::WrapUnique(op_kernel_raw); + std::vector required_constants; + status = tensorflow::XlaOpRegistry::CompileTimeConstantInputs( + *op_kernel, &required_constants); + if (!status.ok()) { + op->emitRemark() << "failed to compute required constants: " + << status.ToString(); + return success(); + } + llvm::SmallDenseSet required_consts; + required_consts.insert(required_constants.begin(), required_constants.end()); + // TensorValue in inputs are backed by tensors which in turn depend on // expressions. So, pre-allocate them to the required size. InlinedVector expressions; @@ -306,45 +371,39 @@ LogicalResult FuncLegalizer::LegalizeOp(Operation* op) { inputs.reserve(op->getNumOperands()); // Prepare the list of Tensor inputs for the kernel. 
- for (Value operand : op->getOperands()) { - // Skip this op if XLA doesn't support this operand type. - auto xla_op_or = hlo_builder_.MakeXlaOp(operand); - if (!xla_op_or.ok()) { - op->emitRemark() << "skipping legalization due to " - << xla_op_or.status().ToString(); + for (auto it : llvm::enumerate(op->getOperands())) { + Value operand = it.value(); + size_t idx = it.index(); + + tensorflow::XlaExpression expr = GetExprForOperand(operand, op); + tensorflow::XlaExpression::Kind kind = expr.kind(); + if (kind == tensorflow::XlaExpression::Kind::kInvalid) return success(); + if (required_consts.count(idx) && + kind != tensorflow::XlaExpression::Kind::kConstant) { + op->emitRemark() << "lowering requires operand #" << idx + << " to be a constant"; return success(); } - ::xla::XlaOp xla_op = xla_op_or.ValueOrDie(); + expressions.push_back(expr); - tensorflow::DataType dtype; - status = tensorflow::ConvertToDataType(operand.getType(), &dtype); - if (!status.ok()) { - op->emitRemark() << "skipping legalization due to " << status.ToString(); - return success(); - } - - auto expression = tensorflow::XlaExpression::XlaOp(xla_op, dtype); - expressions.push_back(expression); - - if (!tensorflow::DataTypeCanUseMemcpy(dtype)) { + if (!tensorflow::DataTypeCanUseMemcpy(expr.dtype())) { op->emitRemark() << "skipping legalization due to unsupported type " << operand.getType(); return success(); } - auto shape_or = expression.GetShape(); + auto shape_or = expr.GetShape(); if (!shape_or.ok()) { op->emitRemark() << "failed to get shape for expression. " - << expression.HumanString(); + << expr.HumanString(); return success(); } tensors.emplace_back( - device_->GetAllocator(tensorflow::AllocatorAttributes()), dtype, + device_->GetAllocator(tensorflow::AllocatorAttributes()), expr.dtype(), shape_or.ValueOrDie()); tensorflow::Tensor& tensor = tensors.back(); - tensorflow::XlaOpKernelContext::AssignExpressionToTensor(expression, - &tensor); + tensorflow::XlaOpKernelContext::AssignExpressionToTensor(expr, &tensor); inputs.emplace_back(&tensor); } @@ -376,13 +435,51 @@ LogicalResult FuncLegalizer::LegalizeOp(Operation* op) { return op->emitError( "expects XlaExpression of kind kXlaOp in compiled output"); auto value = hlo_builder_.GetValue(expr->handle()); - op->getResult(i).replaceAllUsesWith(value); + mlir::OpResult old_result = op->getResult(i); + if (value.getType() != old_result.getType()) { + value = + hlo_builder_.create(value, old_result.getType()); + } + old_result.replaceAllUsesWith(value); } op->erase(); return success(); } +tensorflow::XlaExpression FuncLegalizer::GetExprForOperand(Value operand, + Operation* op) { + ElementsAttr const_attr; + auto defining_op = operand.getDefiningOp(); + if (defining_op && matchPattern(defining_op, m_Constant(&const_attr))) { + tensorflow::Tensor tensor; + auto status = tensorflow::ConvertToTensor(const_attr, &tensor); + if (!status.ok()) { + op->emitRemark() << "skipping legalization due to failed const conversion" + << status.ToString(); + return tensorflow::XlaExpression::Invalid(); + } + return tensorflow::XlaExpression::Constant(tensor); + } + + // Skip this op if XLA doesn't support this operand type. 
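The constant branch of GetExprForOperand above can be read in isolation as the following hedged sketch (not the exact pass code): an operand backed by an op that folds to a constant attribute becomes a kConstant XlaExpression, and everything else falls through to the XlaOp path handled next.

#include "mlir/IR/Attributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/core/framework/tensor.h"

// Returns true and fills `expr` when `operand` is defined by an op that
// folds to a constant attribute (e.g. tf.Const); otherwise does nothing.
static bool TryMakeConstantExpression(mlir::Value operand,
                                      tensorflow::XlaExpression* expr) {
  mlir::ElementsAttr const_attr;
  mlir::Operation* defining_op = operand.getDefiningOp();
  if (!defining_op ||
      !mlir::matchPattern(defining_op, mlir::m_Constant(&const_attr)))
    return false;

  tensorflow::Tensor tensor;
  if (!tensorflow::ConvertToTensor(const_attr, &tensor).ok()) return false;
  *expr = tensorflow::XlaExpression::Constant(tensor);
  return true;
}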
+ auto xla_op_or = hlo_builder_.MakeXlaOp(operand); + if (!xla_op_or.ok()) { + op->emitRemark() << "skipping legalization due to " + << xla_op_or.status().ToString(); + return tensorflow::XlaExpression::Invalid(); + } + ::xla::XlaOp xla_op = xla_op_or.ValueOrDie(); + + tensorflow::DataType dtype; + auto status = tensorflow::ConvertToDataType(operand.getType(), &dtype); + if (!status.ok()) { + op->emitRemark() << "skipping legalization due to " << status.ToString(); + return tensorflow::XlaExpression::Invalid(); + } + return tensorflow::XlaExpression::XlaOp(xla_op, dtype); +} + class LegalizeTF : public PassWrapper { public: LegalizeTF() = default; diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc index bdee1b77cff..43c0911a4a6 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc @@ -19,7 +19,7 @@ limitations under the License. #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "absl/memory/memory.h" #include "llvm/ADT/ArrayRef.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" // from @llvm-project +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Transforms/FoldUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc index e6f3ac02d4f..f0eb3cc1a0f 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc @@ -21,7 +21,7 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project #include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project -#include "mlir/Dialect/LoopOps/LoopOps.h" // from @llvm-project +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project @@ -112,7 +112,7 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern { auto step = rewriter.create( loc, rewriter.getIndexType(), rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); - auto loop = rewriter.create(loc, zero, upper, step); + auto loop = rewriter.create(loc, zero, upper, step); rewriter.setInsertionPointToStart(loop.getBody()); // Compute memrefs for the value to reduce. This makes it easier to just @@ -173,8 +173,7 @@ struct LhloLegalizeToGpu : public PassWrapper { OwningRewritePatternList patterns; ConversionTarget target(getContext()); target.addLegalDialect(); + gpu::GPUDialect, scf::SCFDialect, XlaLhloDialect>(); target.addIllegalOp(); auto func = getFunction(); patterns.insert(func.getContext()); diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc index 54b3acd3787..734a75a4307 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc @@ -18,7 +18,7 @@ limitations under the License. 
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project -#include "mlir/Dialect/LoopOps/LoopOps.h" // from @llvm-project +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project @@ -61,15 +61,15 @@ Value ApplySingleResultLhloCode(Location loc, ValueRange operands, // Converts a block with LHLO ops and with signature: // ^bb(%lhs: memref, %rhs: memref, %res: memref): -// into a reduction operator of loop.reduce by doing buffer allocation for -// scalar arguments and the result of `loop.reduce` to make it compatible with +// into a reduction operator of scf.reduce by doing buffer allocation for +// scalar arguments and the result of `scf.reduce` to make it compatible with // LHLO ops. -void ConvertToReductionOperator(Location loc, loop::ReduceOp reduce_op, +void ConvertToReductionOperator(Location loc, scf::ReduceOp reduce_op, Block* lhlo_block, OpBuilder* b) { Block& loop_reduce_op_body = reduce_op.reductionOperator().front(); OpBuilder::InsertionGuard guard(*b); b->setInsertionPointToStart(&loop_reduce_op_body); - b->create( + b->create( loc, ApplySingleResultLhloCode(loc, loop_reduce_op_body.getArguments(), lhlo_block, b)); } @@ -136,9 +136,9 @@ MappedIvs MapWindowIvsToInput(OpTy op, ValueRange ivs, ValueRange window_ivs, return mapped_ivs; } -// Returns loop::Parallel over a shaped value with static or dynamic shape. -loop::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value, - OpBuilder* b) { +// Returns scf::Parallel over a shaped value with static or dynamic shape. +scf::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value, + OpBuilder* b) { Value zero = b->create(loc, 0); Value one = b->create(loc, 1); @@ -151,10 +151,10 @@ loop::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value, lower.push_back(zero); step.push_back(one); } - return b->create(loc, lower, upper, step); + return b->create(loc, lower, upper, step); } -// Converts `xla_lhlo.ReduceOp` into two loop::ParallelOp and a loop::ReduceOp. +// Converts `xla_lhlo.ReduceOp` into two scf::ParallelOp and a scf::ReduceOp. // The outper `ParallelOp` refers to the parallel loops if there are // any. The inner `ParalleOp` refers to the reduction loops and `ReduceOp` // contains the reduction operator. 
@@ -170,10 +170,10 @@ loop::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value, // is roughly converted into: // // %init = load %init_buf[] : memref -// loop.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) { -// %result = loop.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) { +// scf.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) { +// %result = scf.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) { // %elem_to_reduce = load %buffer[%i, %j, %k] : memref<100x10x5xf32> -// loop.reduce(%elem_to_reduce) { +// scf.reduce(%elem_to_reduce) { // ^bb0(%elem: f32, %acc: f32): // no predecessors // elem_buf = alloc() : memref // store %elem, elem_buf[] : memref @@ -181,11 +181,11 @@ loop::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value, // store %acc, acc_buf[] : memref // // %acc_result = load acc_buf[] : memref -// loop.reduce.return %acc_result : f32 +// scf.reduce.return %acc_result : f32 // } : f32 -// loop.yield +// scf.yield // } : f32 -// loop.yield +// scf.yield // } class ReduceOpConverter : public OpConversionPattern { public: @@ -197,7 +197,7 @@ class ReduceOpConverter : public OpConversionPattern { // TODO(b/137624192) Implement variadic reduce. if (xla_reduce_op.out().size() != 1) return failure(); - loop::ReduceOp reduce_op = + scf::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops(xla_reduce_op, &rewriter); ConvertToReductionOperator(xla_reduce_op.getLoc(), reduce_op, &xla_reduce_op.body().front(), &rewriter); @@ -206,26 +206,26 @@ class ReduceOpConverter : public OpConversionPattern { } private: - // Creates nested `loop.parallel` ops with `loop.reduce`. The outer ParallelOp + // Creates nested `scf.parallel` ops with `scf.reduce`. The outer ParallelOp // refers to the parallel dimensions of `xla_reduce_op` if any and the inner - // ParallelOp refers to the reduction dimensions. The loop.reduce op is + // ParallelOp refers to the reduction dimensions. The scf.reduce op is // returned. // // If the reduction argument is a memref<100x10x5xf32> and the // reduction is performed along dimension 1 then this method will generate // // %init = load %init_buf[] : memref - // loop.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) { - // %result = loop.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) { + // scf.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) { + // %result = scf.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) { // %elem_to_reduce = load %buffer[%i, %j, %k] : memref<100x10x5xf32> - // loop.reduce(%elem_to_reduce) { + // scf.reduce(%elem_to_reduce) { // // } : f32 - // loop.yield + // scf.yield // } : f32 - // loop.yield + // scf.yield // } - loop::ReduceOp CreateReduceOpInNestedParallelLoops( + scf::ReduceOp CreateReduceOpInNestedParallelLoops( xla_lhlo::ReduceOp xla_reduce_op, ConversionPatternRewriter* rewriter) const { auto loc = xla_reduce_op.getLoc(); @@ -254,13 +254,13 @@ class ReduceOpConverter : public OpConversionPattern { SmallVector init_value = { rewriter->create(loc, *xla_reduce_op.init_values().begin())}; // Outer ParallelOp is not needed if it is a reduction across all dims. 
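The loop construction shared by these converters (MakeLoopOverShape above, and the outer/inner loops created just below) follows one recipe: constant-index bounds per dimension feeding an scf.parallel. A hedged sketch for the static-shape case; the real helper also handles dynamic dimensions via std.dim.

#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"

// Build an scf.parallel that visits every element of a statically shaped
// memref: lower bound 0, upper bound = dim size, step 1 for each dimension.
static mlir::scf::ParallelOp MakeLoopOverStaticShape(mlir::Location loc,
                                                     mlir::Value memref,
                                                     mlir::OpBuilder* b) {
  auto type = memref.getType().cast<mlir::MemRefType>();
  mlir::Value zero = b->create<mlir::ConstantIndexOp>(loc, 0);
  mlir::Value one = b->create<mlir::ConstantIndexOp>(loc, 1);

  llvm::SmallVector<mlir::Value, 4> lower, upper, step;
  for (int64_t dim : type.getShape()) {
    lower.push_back(zero);
    upper.push_back(b->create<mlir::ConstantIndexOp>(loc, dim));
    step.push_back(one);
  }
  // Without init values the loop yields no results; with init values (as in
  // the reduction converters) it additionally carries the reduced scalars.
  return b->create<mlir::scf::ParallelOp>(loc, lower, upper, step);
}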
- loop::ParallelOp outer; + scf::ParallelOp outer; if (!parallel_lower.empty()) { - outer = rewriter->create(loc, parallel_lower, - parallel_upper, parallel_step); + outer = rewriter->create(loc, parallel_lower, + parallel_upper, parallel_step); rewriter->setInsertionPointToStart(outer.getBody()); } - loop::ParallelOp inner = rewriter->create( + scf::ParallelOp inner = rewriter->create( loc, reduce_lower, reduce_upper, reduce_step, init_value); Value reduction_result = *inner.getResults().begin(); @@ -294,7 +294,7 @@ class ReduceOpConverter : public OpConversionPattern { rewriter->setInsertionPointToStart(inner.getBody()); Value elem = rewriter->create( loc, *xla_reduce_op.operands().begin(), indices); - return rewriter->create(loc, elem); + return rewriter->create(loc, elem); } }; @@ -314,8 +314,8 @@ class ReduceOpConverter : public OpConversionPattern { // accumulator = reduction_operator(output[O], value) // output[O] = accumulator // -// Converts `xla_lhlo.ReduceWindowOp` into two loop::ParallelOp and a -// loop::ReduceOp. +// Converts `xla_lhlo.ReduceWindowOp` into two scf::ParallelOp and a +// scf::ReduceOp. // The outper `ParallelOp` refers to the parallel loops that traverese output // buffer. The inner `ParalleOp` refers to the reduction loops that traverse // reduction windows and `ReduceOp` contains the reduction operator. @@ -341,20 +341,20 @@ class ReduceOpConverter : public OpConversionPattern { // is roughly converted into: // // %neutral_elem = load %init_buf[] : memref -// loop.parallel (%i, %j) = (%c0, %c0) to (%c56, %c56) step (%c1, %c1) { -// %result = loop.parallel (%iw, %jw) = (%c0, %c0) +// scf.parallel (%i, %j) = (%c0, %c0) to (%c56, %c56) step (%c1, %c1) { +// %result = scf.parallel (%iw, %jw) = (%c0, %c0) // to (%c3, %c3) step (%c1, %c1) neutral_elem (%0) -> f32 { // %in_bounds = // %elem = load %operand[%computed_i, %computed_j] // %elem_or_neutral = select %in_bounds, %elem, %neutral_elem : f32 -// loop.reduce(%elem_to_reduce) : f32 { +// scf.reduce(%elem_to_reduce) : f32 { // ^bb0(%arg7: f32, %arg8: f32): // // } -// loop.yield +// scf.yield // } // store %result, %output_buffer[%i, %j] : memref<56x56xf32> -// loop.yield +// scf.yield // } // return // } @@ -366,12 +366,12 @@ class ReduceWindowOpConverter LogicalResult matchAndRewrite( xla_lhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef /*args*/, ConversionPatternRewriter& rewriter) const final { - loop::ParallelOp output_loop, window_loop; + scf::ParallelOp output_loop, window_loop; std::tie(output_loop, window_loop) = CreateParallelLoopsToTraverseOutputAndWindow(xla_reduce_window_op, &rewriter); - loop::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops( + scf::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops( xla_reduce_window_op, output_loop, window_loop, &rewriter); ConvertToReductionOperator(xla_reduce_window_op.getLoc(), reduce_op, @@ -381,7 +381,7 @@ class ReduceWindowOpConverter } private: - std::pair + std::pair CreateParallelLoopsToTraverseOutputAndWindow( xla_lhlo::ReduceWindowOp xla_reduce_window_op, ConversionPatternRewriter* rewriter) const { @@ -405,7 +405,7 @@ class ReduceWindowOpConverter window_upper.push_back( rewriter->create(loc, window_dim.getSExtValue())); } - auto window_loop = rewriter->create( + auto window_loop = rewriter->create( loc, window_lower, window_upper, window_step, init_value); Value reduction_result = *window_loop.getResults().begin(); @@ -414,9 +414,9 @@ class ReduceWindowOpConverter return std::make_pair(output_loop, window_loop); } - loop::ReduceOp 
CreateReduceOpInNestedParallelLoops( + scf::ReduceOp CreateReduceOpInNestedParallelLoops( xla_lhlo::ReduceWindowOp xla_reduce_window_op, - loop::ParallelOp output_loop, loop::ParallelOp window_loop, + scf::ParallelOp output_loop, scf::ParallelOp window_loop, ConversionPatternRewriter* rewriter) const { rewriter->setInsertionPointToStart(window_loop.getBody()); auto loc = xla_reduce_window_op.getLoc(); @@ -436,20 +436,20 @@ class ReduceWindowOpConverter xla_reduce_window_op, output_loop.getInductionVars(), window_loop.getInductionVars(), rewriter); - auto elem_or_init = rewriter->create( + auto elem_or_init = rewriter->create( loc, xla_operand_type.getElementType(), mapped_ivs.in_bounds, /*withElseRegion=*/true); OpBuilder then_builder = elem_or_init.getThenBodyBuilder(); Value elem = then_builder.create( loc, xla_reduce_window_op.operand(), mapped_ivs.ivs); - then_builder.create(loc, elem); + then_builder.create(loc, elem); OpBuilder else_builder = elem_or_init.getElseBodyBuilder(); - else_builder.create(loc, *window_loop.initVals().begin()); + else_builder.create(loc, *window_loop.initVals().begin()); - return rewriter->create(loc, - *elem_or_init.results().begin()); + return rewriter->create(loc, + *elem_or_init.results().begin()); } }; @@ -457,16 +457,16 @@ class ReduceWindowOpConverter // https://www.tensorflow.org/xla/operation_semantics#selectandscatter // // Pseudocode: -// loop.parallel(coordinates O in the output): +// scf.parallel(coordinates O in the output): // output[O] = init -// loop.parallel(coordinates S in the source): +// scf.parallel(coordinates S in the source): // selected_ivs = 0 // selected_val = 0 // initialized_flag = false -// loop.for (first dim W_1 in the window) +// scf.for (first dim W_1 in the window) // iter_args (selected_ivs, selected_val, initialized_flag): // ... 
-// loop.for (last dim W_N in the window): +// scf.for (last dim W_N in the window): // iter_args (selected_ivs, selected_val, initialized_flag): // I = S * stride + W - pad_low // if I within bounds of operand: @@ -490,7 +490,7 @@ class SelectAndScatterOpConverter ConversionPatternRewriter& rewriter) const final { auto loc = s_and_s_op.getLoc(); InitializeOutput(s_and_s_op, &rewriter); - loop::ParallelOp loop_over_src = + scf::ParallelOp loop_over_src = MakeLoopOverShape(loc, s_and_s_op.source(), &rewriter); rewriter.setInsertionPointToStart(loop_over_src.getBody()); @@ -520,7 +520,7 @@ class SelectAndScatterOpConverter auto loc = s_and_s_op.getLoc(); Value init_value = b->create(loc, s_and_s_op.init_value()); - loop::ParallelOp loop_over_output = + scf::ParallelOp loop_over_output = MakeLoopOverShape(loc, s_and_s_op.out(), b); OpBuilder::InsertionGuard guard(*b); b->setInsertionPointToStart(loop_over_output.getBody()); @@ -531,10 +531,10 @@ class SelectAndScatterOpConverter struct WindowLoops { SmallVector selected_ivs; SmallVector window_ivs; - loop::ForOp inner_loop; + scf::ForOp inner_loop; }; WindowLoops InsertWindowLoops(xla_lhlo::SelectAndScatterOp s_and_s_op, - loop::ParallelOp loop_over_src, + scf::ParallelOp loop_over_src, OpBuilder* b) const { auto loc = s_and_s_op.getLoc(); Value zero = b->create(loc, 0); @@ -558,12 +558,12 @@ class SelectAndScatterOpConverter s_and_s_op.window_dimensions()->getIntValues()) { Value upper = b->create(loc, window_dim.getSExtValue()); result.inner_loop = - b->create(loc, zero, upper, one, iter_args); + b->create(loc, zero, upper, one, iter_args); if (b->getInsertionBlock() == loop_over_src.getBody()) { ip = b->saveInsertionPoint(); result.selected_ivs = result.inner_loop.getResults().take_front(rank); } else { - b->create(loc, result.inner_loop.getResults()); + b->create(loc, result.inner_loop.getResults()); } b->setInsertionPointToStart(result.inner_loop.getBody()); iter_args = ValueRange{result.inner_loop.getRegionIterArgs()}; @@ -599,7 +599,7 @@ class SelectAndScatterOpConverter }; SmallVector SelectIvs(xla_lhlo::SelectAndScatterOp s_and_s_op, - loop::ParallelOp loop_over_src, + scf::ParallelOp loop_over_src, OpBuilder* b) const { auto loc = s_and_s_op.getLoc(); @@ -614,7 +614,7 @@ class SelectAndScatterOpConverter IterArgs ivs_val_flag(window_loops.inner_loop.getRegionIterArgs()); - auto if_in_bounds = inner_loop_b.create( + auto if_in_bounds = inner_loop_b.create( loc, window_loops.inner_loop.getResultTypes(), mapped_ivs.in_bounds, /*withElseRegion=*/true); @@ -623,16 +623,16 @@ class SelectAndScatterOpConverter OpBuilder in_bounds_then_b = if_in_bounds.getThenBodyBuilder(); auto select_or_init_results = SelectOrInitialize( s_and_s_op, mapped_ivs.ivs, &ivs_val_flag, &in_bounds_then_b); - in_bounds_then_b.create(loc, select_or_init_results); + in_bounds_then_b.create(loc, select_or_init_results); } // Case when we are in the pad. 
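The iter_args plumbing used by InsertWindowLoops and SelectIvs, including the pad case handled just below (which simply re-yields the incoming state), follows the standard scf.for/scf.if pattern. The sketch here threads a running float maximum; `buffer` is assumed to be a 1-D float memref and all names are illustrative placeholders rather than the pass's own code.

#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"

// Compute max(init, buffer[0..upper)) by carrying the running maximum in the
// scf.for iter_args and selecting per element with an scf.if.
static mlir::Value RunningMax(mlir::Location loc, mlir::Value buffer,
                              mlir::Value upper, mlir::Value init,
                              mlir::OpBuilder* b) {
  mlir::Value zero = b->create<mlir::ConstantIndexOp>(loc, 0);
  mlir::Value one = b->create<mlir::ConstantIndexOp>(loc, 1);

  auto loop =
      b->create<mlir::scf::ForOp>(loc, zero, upper, one, mlir::ValueRange(init));
  mlir::OpBuilder::InsertionGuard guard(*b);
  b->setInsertionPointToStart(loop.getBody());

  mlir::Value acc = *loop.getRegionIterArgs().begin();
  mlir::Value elem =
      b->create<mlir::LoadOp>(loc, buffer, loop.getInductionVar());
  mlir::Value take_new =
      b->create<mlir::CmpFOp>(loc, mlir::CmpFPredicate::OGT, elem, acc);

  // scf.if with a result: each branch must end in scf.yield.
  auto if_op = b->create<mlir::scf::IfOp>(loc, acc.getType(), take_new,
                                          /*withElseRegion=*/true);
  mlir::OpBuilder then_b = if_op.getThenBodyBuilder();
  then_b.create<mlir::scf::YieldOp>(loc, elem);
  mlir::OpBuilder else_b = if_op.getElseBodyBuilder();
  else_b.create<mlir::scf::YieldOp>(loc, acc);

  // Forward the selected value to the next iteration of the scf.for.
  b->create<mlir::scf::YieldOp>(loc, if_op.getResults());
  return loop.getResult(0);
}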
{ OpBuilder in_bounds_else_b = if_in_bounds.getElseBodyBuilder(); - in_bounds_else_b.create(loc, ivs_val_flag.to_vector()); + in_bounds_else_b.create(loc, ivs_val_flag.to_vector()); } - inner_loop_b.create(loc, if_in_bounds.getResults()); + inner_loop_b.create(loc, if_in_bounds.getResults()); return window_loops.selected_ivs; } @@ -647,8 +647,8 @@ class SelectAndScatterOpConverter Value operand_elem = b->create(loc, s_and_s_op.operand(), operand_ivs); auto if_init = - b->create(loc, iter_arg_types, ivs_val_flag->is_init(), - /*withElseRegion=*/true); + b->create(loc, iter_arg_types, ivs_val_flag->is_init(), + /*withElseRegion=*/true); // Init == true, i.e. iter args are already initialized with a selected // element in boundaries of the operand. Select function has to be computed // here. @@ -660,32 +660,31 @@ class SelectAndScatterOpConverter ApplySingleResultLhloCode(loc, {operand_elem, ivs_val_flag->value()}, &lhlo_select, &if_init_then_b); - auto if_pred = - if_init_then_b.create(loc, iter_arg_types, pred, - /*withElseRegion=*/true); + auto if_pred = if_init_then_b.create(loc, iter_arg_types, pred, + /*withElseRegion=*/true); // Pred == true, therefore pack newly selected ivs, val and init flag back // to iter_args and return. { OpBuilder if_pred_then_b = if_pred.getThenBodyBuilder(); - if_pred_then_b.create( + if_pred_then_b.create( loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector()); } // Pred == false, therefore return old iter_args. { OpBuilder if_pred_else_b = if_pred.getElseBodyBuilder(); - if_pred_else_b.create(loc, ivs_val_flag->to_vector()); + if_pred_else_b.create(loc, ivs_val_flag->to_vector()); } - if_init_then_b.create(loc, if_pred.getResults()); + if_init_then_b.create(loc, if_pred.getResults()); } // Init == false, i.e. only pad was visited before and this is the first // element in the boundaries of the operand. 
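The IterArgs helper referenced above exists because scf.for and scf.if only thread flat value lists. A minimal sketch of that packing scheme, assuming the layout used in SelectAndScatterOpConverter (selected indices first, then the selected value, then the init flag):

#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Value.h"

// Structured view over the flat loop-carried state of the window loops.
// Flat order: [iv_0, ..., iv_{rank-1}, selected_value, initialized_flag].
struct LoopCarriedState {
  explicit LoopCarriedState(mlir::ValueRange flat) {
    size_t rank = flat.size() - 2;
    ivs.assign(flat.begin(), flat.begin() + rank);
    value = flat[rank];
    is_init = flat[rank + 1];
  }

  // Re-flatten for use as scf.yield operands or scf.if results.
  llvm::SmallVector<mlir::Value, 4> to_vector() const {
    llvm::SmallVector<mlir::Value, 4> packed(ivs.begin(), ivs.end());
    packed.push_back(value);
    packed.push_back(is_init);
    return packed;
  }

  llvm::SmallVector<mlir::Value, 2> ivs;
  mlir::Value value;
  mlir::Value is_init;
};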
{ OpBuilder if_init_else_b = if_init.getElseBodyBuilder(); - if_init_else_b.create( + if_init_else_b.create( loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector()); } return if_init.getResults(); @@ -708,7 +707,7 @@ struct LhloLegalizeToParallelLoops ConversionTarget target(getContext()); target.addLegalDialect(); + scf::SCFDialect, XlaLhloDialect>(); target.addIllegalOp(); diff --git a/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h b/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h index 6178434c8bb..fed21e9bafc 100644 --- a/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h +++ b/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h @@ -63,6 +63,7 @@ MAP_HLO_TO_LHLO(RemOp); MAP_HLO_TO_LHLO(RsqrtOp); MAP_HLO_TO_LHLO(SelectOp); MAP_HLO_TO_LHLO(SignOp); +MAP_HLO_TO_LHLO(SinOp); MAP_HLO_TO_LHLO(SqrtOp); MAP_HLO_TO_LHLO(SubOp); MAP_HLO_TO_LHLO(TanhOp); diff --git a/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h b/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h index 8296011bf54..c317dc36b3c 100644 --- a/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h +++ b/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h @@ -227,6 +227,28 @@ inline Value MapLhloOpToStdScalarOp( loc, result_types, args, b); } +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, + b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, b); +} + template <> inline Value MapLhloOpToStdScalarOp( Location loc, ArrayRef result_types, ArrayRef args, @@ -259,11 +281,9 @@ inline Value MapLhloOpToStdScalarOp( // No conversion is needed for the same width integers return args.front(); } - // TODO(dfki-ehna): Add other primitive type conversions - // if (mlir::FpToSiOp::areCastCompatible(sourceType, targetType)) { - // return b.create(loc, result_types, - // args,mlir::None); - // } + if (mlir::FPToSIOp::areCastCompatible(sourceType, targetType)) { + return b->create(loc, result_types, args, mlir::None); + } return nullptr; } @@ -275,6 +295,14 @@ inline Value MapLhloOpToStdScalarOp( loc, result_types, args, b); } +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); +} + /// Implements the conversion of XLA op to scalar op (to use within region of a /// linalg.generic op) for compare-select style operations like min/max. 
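The new float-to-signed-integer branch in the ConvertOp mapping above reduces to a cast-compatibility check followed by emitting std.fptosi. A hedged standalone sketch, with the dispatch boilerplate of MapLhloOpToStdScalarOp omitted:

#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"

// Lower an element-wise float->signed-int convert to std.fptosi, guarded by
// the same cast-compatibility check the pass now uses. Returns a null Value
// when the combination is unsupported so the caller can try another lowering.
static mlir::Value EmitFloatToSignedIntCast(mlir::Location loc,
                                            mlir::Type result_type,
                                            mlir::Value arg,
                                            mlir::OpBuilder* b) {
  mlir::Type source_type = arg.getType();
  if (mlir::FPToSIOp::areCastCompatible(source_type, result_type))
    return b->create<mlir::FPToSIOp>(loc, result_type, arg, mlir::None);
  return nullptr;
}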
template diff --git a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc index a4ffa57957e..bf666400900 100644 --- a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc +++ b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc @@ -50,12 +50,6 @@ static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end, template bool CreateStaticBroadcastsForBinaryOp(SrcOp op, PatternRewriter *rewriter, Value *out_lhs, Value *out_rhs) { - if (!op.broadcast_dimensions().hasValue()) { - // Note: the op may still have an implicit broadcast on it, such as - // for (tensor<1xf32>, tensor<4xf32>). - return false; - } - // Insert BroadcastInDimOps for the left-hand-side and right-hand-side args, // replacing the original LHS and RHS args in the source op with the results // of the broadcasts. @@ -79,25 +73,7 @@ bool CreateStaticBroadcastsForBinaryOp(SrcOp op, PatternRewriter *rewriter, auto lhs_rank = lhs_ranked_type.getRank(); auto rhs_rank = rhs_ranked_type.getRank(); - - // Set broadcast_dimensions to [0, ..., rank] for the higher rank arg. - // Use the original op.broadcast_dimensions for the lower rank arg. - auto higher_rank_broadcast_dims = - GetI64ElementsAttrForSeq(0, std::max(lhs_rank, rhs_rank), rewriter); - DenseIntElementsAttr lhs_broadcast_dims; - DenseIntElementsAttr rhs_broadcast_dims; - if (lhs_rank > rhs_rank) { - lhs_broadcast_dims = higher_rank_broadcast_dims; - rhs_broadcast_dims = op.broadcast_dimensions().getValue(); - } else if (lhs_rank < rhs_rank) { - lhs_broadcast_dims = op.broadcast_dimensions().getValue(); - rhs_broadcast_dims = higher_rank_broadcast_dims; - } else { - // This shouldn't happen for legal ops. If the broadcast_dimensions - // attribute is set, the ranks should be different. - // TODO(scotttodd): Add a custom verification for ops and assert here. - return false; - } + ArrayRef op_shape = op_ranked_type.getShape(); // BroadcastInDimOp must have the same element type for operands and results, // so preserve the original output shape and the original input element type. 
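The rewritten CreateStaticBroadcastsForBinaryOp below applies the same per-operand rule to both sides: insert a broadcast_in_dim only when the operand's shape differs from the result shape. A hedged sketch of that rule for a single operand; `broadcast_dims` is assumed to be either the identity sequence 0..rank-1 (for the side already at full rank) or the op's broadcast_dimensions attribute (for the lower-rank side).

#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"

// Broadcast `operand` up to the result shape of the binary op, preserving its
// element type; leave it untouched when it already matches.
static mlir::Value BroadcastToResultShape(
    mlir::Location loc, mlir::Value operand,
    mlir::RankedTensorType result_type,
    mlir::DenseIntElementsAttr broadcast_dims, mlir::OpBuilder* b) {
  auto operand_type = operand.getType().cast<mlir::RankedTensorType>();
  if (operand_type.getShape() == result_type.getShape())
    return operand;  // Already at the result shape: no broadcast needed.

  auto broadcasted_type = mlir::RankedTensorType::get(
      result_type.getShape(), operand_type.getElementType());
  return b->createOrFold<mlir::xla_hlo::BroadcastInDimOp>(
      loc, broadcasted_type, operand, broadcast_dims);
}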
@@ -105,16 +81,32 @@ bool CreateStaticBroadcastsForBinaryOp(SrcOp op, PatternRewriter *rewriter, // broadcast_in_dim (tensor<1x4xf32>) -> tensor<1x4xf32> // broadcast_in_dim (tensor<4xf32>) -> tensor<1x4xf32> // SrcOp (tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<1x4xi1> - ArrayRef op_shape = op_ranked_type.getShape(); - auto lhs_type = - RankedTensorType::get(op_shape, lhs_ranked_type.getElementType()); - auto rhs_type = - RankedTensorType::get(op_shape, rhs_ranked_type.getElementType()); + if (lhs_ranked_type.getShape() != op_ranked_type.getShape()) { + auto type = + RankedTensorType::get(op_shape, lhs_ranked_type.getElementType()); + DenseIntElementsAttr attr = GetI64ElementsAttrForSeq(0, lhs_rank, rewriter); + if (lhs_rank < rhs_rank) { + attr = op.broadcast_dimensions().getValue(); + } - *out_lhs = rewriter->createOrFold(op.getLoc(), lhs_type, - lhs, lhs_broadcast_dims); - *out_rhs = rewriter->createOrFold(op.getLoc(), rhs_type, - rhs, rhs_broadcast_dims); + lhs = + rewriter->createOrFold(op.getLoc(), type, lhs, attr); + } + + if (rhs_ranked_type.getShape() != op_ranked_type.getShape()) { + auto type = + RankedTensorType::get(op_shape, rhs_ranked_type.getElementType()); + DenseIntElementsAttr attr = GetI64ElementsAttrForSeq(0, rhs_rank, rewriter); + if (rhs_rank < lhs_rank) { + attr = op.broadcast_dimensions().getValue(); + } + + rhs = + rewriter->createOrFold(op.getLoc(), type, rhs, attr); + } + + *out_lhs = lhs; + *out_rhs = rhs; return true; } @@ -359,9 +351,15 @@ struct CompareWithBroadcastConvert : public OpRewritePattern { void SetupMaterializeBroadcastsLegality(MLIRContext *context, ConversionTarget *conversionTarget) { -#define ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(OpType) \ - conversionTarget->addDynamicallyLegalOp( \ - [](OpType op) { return !op.broadcast_dimensions().hasValue(); }); +#define ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(OpType) \ + conversionTarget->addDynamicallyLegalOp([](OpType op) { \ + if (op.broadcast_dimensions().hasValue()) return false; \ + auto l = op.lhs().getType().cast(); \ + auto r = op.rhs().getType().cast(); \ + if (!l.hasRank() || !r.hasRank()) return false; \ + return l.getShape() == r.getShape(); \ + }); + // Binary elementwise ops. ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(AddOp); ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(Atan2Op); diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h index 2d0164981a3..39375e210d5 100644 --- a/tensorflow/compiler/mlir/xla/transforms/passes.h +++ b/tensorflow/compiler/mlir/xla/transforms/passes.h @@ -81,8 +81,8 @@ std::unique_ptr> createLegalizeToGpuPass(); // Fuses linalg ops obtained after LHLO lowering. To enable fusion, // operations are first tiled. // -// When 'use_parallel_loops' is set, the tiling will use loop.parallel -// operations. Otherwise, loop.for operations are used. +// When 'use_parallel_loops' is set, the tiling will use scf.parallel +// operations. Otherwise, scf.for operations are used. // // 'tile_sizes' provides the tile sizes to use for tiling. If the linalg // operation has more dimensions than tile sizes provided, 1 is used as diff --git a/tensorflow/compiler/mlir/xla/transforms/rewriters.h b/tensorflow/compiler/mlir/xla/transforms/rewriters.h index ad81cda19b9..9cde6f84474 100644 --- a/tensorflow/compiler/mlir/xla/transforms/rewriters.h +++ b/tensorflow/compiler/mlir/xla/transforms/rewriters.h @@ -23,6 +23,7 @@ limitations under the License. 
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project namespace mlir { +class BufferAssignmentPlacer; namespace xla_hlo { // Collection of rewrite patterns for lowering a general dot product. @@ -38,9 +39,9 @@ void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns, MLIRContext *ctx); // Collection of rewrite patterns for lowering of HLO to LHLO dialect. -void populateHLOToLHLOConversionPattern(MLIRContext *context, - OwningRewritePatternList *patterns); - +void populateHLOToLHLOConversionPattern( + MLIRContext *context, BufferAssignmentPlacer *bufferAssignment, + TypeConverter *converter, OwningRewritePatternList *patterns); // Collection of rewrite patterns for lowering of HLO to Linalg dialect. void populateHLOToLinalgConversionPattern(MLIRContext *context, OwningRewritePatternList *patterns); diff --git a/tensorflow/compiler/mlir/xla/transforms/test_infer_shaped_type_pass.cc b/tensorflow/compiler/mlir/xla/transforms/test_infer_shaped_type_pass.cc index 8976bd5b7d2..71441656c08 100644 --- a/tensorflow/compiler/mlir/xla/transforms/test_infer_shaped_type_pass.cc +++ b/tensorflow/compiler/mlir/xla/transforms/test_infer_shaped_type_pass.cc @@ -38,7 +38,8 @@ struct InferReturnTypeComponentsPattern : public RewritePattern { SmallVector components; if (failed(defining_op_int.inferReturnTypeComponents( op->getContext(), op->getLoc(), defining_op->getOperands(), - defining_op->getAttrs(), defining_op->getRegions(), components))) { + defining_op->getAttrDictionary(), defining_op->getRegions(), + components))) { return failure(); } diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.cc index ee75ceac2d1..a12bd9e7c1a 100644 --- a/tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.h" + #include #include @@ -72,15 +74,6 @@ StatusOr> HloModuleFromProto( // dialect. class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault { public: - // Populate the MLIR `module` with the computation from the `hlo_module` using - // the provided buffer `assignment`. The returned `Status` indicates success - // or failure in the conversion. - static Status EmitModule(const BufferAssignment& assignment, - const HloModule& hlo_module, ModuleOp module) { - return LhloDialectEmitter(assignment, hlo_module, module).Run(); - } - - private: // Main entry point of the processing: after this call the MLIR ModuleOp is // populated with the computation from the HloModule. The returned `Status` // indicates success or failure in the conversion. @@ -94,24 +87,13 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault { builder_(module.getContext()), i8_type_(builder_.getIntegerType(8)) {} - Status DefaultAction(HloInstruction* hlo) final { - return ::xla::Unimplemented("unsupported HLO %s", hlo->name()); - } + private: + Status DefaultAction(HloInstruction* instr) final; // Computation parameters don't need any specific handling when they are // visited, they are already processed when we enter a new computation. 
Status HandleParameter(HloInstruction* instr) final { return Status::OK(); } - // HLO Copy is translated 1:1 to an lhlo.copy operation. - Status HandleCopy(HloInstruction* instr) final { - TF_ASSIGN_OR_RETURN(Value source, GetOrCreateView(instr->operand(0))); - TF_ASSIGN_OR_RETURN(Value dest, GetOrCreateView(instr)); - if (source != dest) - builder_.create(getLocation(instr), - llvm::ArrayRef{}, source, dest); - return Status::OK(); - } - // Helper function to create view in a buffer for a given slice. The view is // cached in the `slices_` map. Value GetOrCreateView(const BufferAllocation::Slice& slice); @@ -160,6 +142,98 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault { Type i8_type_; }; +Status LhloDialectEmitter::DefaultAction(HloInstruction* instr) { + llvm::SmallVector operands(instr->operand_count() + 1); + for (int arg_idx = 0; arg_idx < instr->operand_count(); ++arg_idx) { + TF_ASSIGN_OR_RETURN(operands[arg_idx], + GetOrCreateView(instr->operand(arg_idx))); + } + + TF_ASSIGN_OR_RETURN(operands.back(), GetOrCreateView(instr)); + Location loc = getLocation(instr); + ArrayRef> attrs; + ArrayRef rets{}; + + using ::xla::HloOpcode; + switch (instr->opcode()) { + case HloOpcode::kAbs: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kAdd: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kAnd: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kCeil: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kComplex: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kCopy: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kCos: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kDivide: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kExp: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kImag: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kLog: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kMaximum: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kMinimum: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kMultiply: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kNegate: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kReal: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kRemainder: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kRsqrt: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kSelect: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kSign: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kSqrt: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kSubtract: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + case HloOpcode::kTanh: + builder_.create(loc, rets, operands, attrs); + return Status::OK(); + default: + llvm::errs() << instr->ToString(); + return tensorflow::errors::Internal( + absl::StrCat("LHLO opcode ", 
::xla::HloOpcodeString(instr->opcode()), + " is not supported.")); + } + return Status::OK(); +} + Value LhloDialectEmitter::GetOrCreateView( const BufferAllocation::Slice& slice) { // Check if we already have a view for this slice, otherwise we need to create @@ -177,17 +251,15 @@ Value LhloDialectEmitter::GetOrCreateView( // Create the view for this slice size, possible with an affine map to model // the offset. The result is cached in the slices_ map. - SmallVector offset_map; - if (slice.offset()) { - offset_map.push_back(AffineMap::get( - /*dimCount=*/1, /*symbolCount=*/0, - {getAffineDimExpr(0, builder_.getContext()) + slice.offset()}, - builder_.getContext())); - } - auto slice_type = MemRefType::get({slice.size()}, i8_type_, offset_map); + // The std.view result type does not carry the static offset: this is not + // useful information. Rather, the view op must have the static offset. + auto slice_type = MemRefType::get({slice.size()}, i8_type_, {}); - auto slice_view = builder_.create( - alloc_buffer.getLoc(), slice_type, alloc_buffer, /*operands=*/llvm::None); + Value byte_shift = + builder_.create(alloc_buffer.getLoc(), slice.offset()); + auto slice_view = + builder_.create(alloc_buffer.getLoc(), slice_type, alloc_buffer, + byte_shift, /*sizes=*/ArrayRef{}); slices_.insert({slice_key, slice_view}); return slice_view; } @@ -203,9 +275,12 @@ StatusOr LhloDialectEmitter::GetOrCreateView( Value slice_view = GetOrCreateView(out_slice); TF_ASSIGN_OR_RETURN(Type out_type, ::xla::ConvertShapeToType( target_shape, builder_)); + Value byte_shift = + builder_.create(builder_.getUnknownLoc(), 0); if (slice_view.getType() != out_type) - slice_view = builder_.create(builder_.getUnknownLoc(), out_type, - slice_view, llvm::None); + slice_view = + builder_.create(builder_.getUnknownLoc(), out_type, slice_view, + byte_shift, /*sizes=*/ArrayRef{}); return slice_view; } @@ -334,8 +409,7 @@ Status ConvertModule(ModuleOp module, StringRef platform_name) { module.ensureTerminator(module.getBodyRegion(), builder, module.getLoc()); TF_RETURN_WITH_CONTEXT_IF_ERROR( - LhloDialectEmitter::EmitModule(*assignment, *optimized_hlo_module, - module), + HloToLhloModule(*assignment, *optimized_hlo_module, module), "converting HLO to LHLO"); return Status::OK(); @@ -372,6 +446,11 @@ std::unique_ptr> createXlaHloToLhloWithXlaPass() { return std::make_unique(); } +Status HloToLhloModule(const BufferAssignment& assignment, + const HloModule& hlo_module, ModuleOp module) { + return LhloDialectEmitter(assignment, hlo_module, module).Run(); +} + static PassRegistration registration( "xla-hlo-to-lhlo-with-xla", "Emit LHLO from HLO using the existing XLA implementation"); diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.h b/tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.h new file mode 100644 index 00000000000..1018bdbf408 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.h @@ -0,0 +1,34 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_XLA_HLO_TO_LHLO_WITH_XLA_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_XLA_HLO_TO_LHLO_WITH_XLA_H_ + +#include "mlir/IR/Module.h" // from @llvm-project +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace mlir { + +// Populate the MLIR `module` with the computation from the `hlo_module` using +// the provided buffer `assignment`. The returned `Status` indicates success +// or failure in the conversion. +tensorflow::Status HloToLhloModule(const xla::BufferAssignment& assignment, + const xla::HloModule& hlo_module, + ModuleOp module); + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_XLA_HLO_TO_LHLO_WITH_XLA_H_ diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc index 1a206d5d8a3..799a20aa693 100644 --- a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc @@ -84,7 +84,8 @@ class PointwiseToLinalgConverter : public OpConversionPattern { emitError(loc, "lhlo to linalg conversion expects ranked args"); return failure(); } - if (!argType.getElementType().isSignlessIntOrFloat()) { + auto elemTy = argType.getElementType(); + if (!elemTy.isSignlessIntOrFloat() && !elemTy.template isa()) { return failure(); } @@ -284,34 +285,32 @@ class BroadcastInDimConverter broadcastOp.operand().getType().template cast(); unsigned nloops = resultType.getRank(); + // The input is a scalar, i.e. this is a scalar broadcast op. + if (operandType.getRank() == 0) { + return b->getAffineMapArrayAttr( + {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()), + b->getMultiDimIdentityMap(nloops)}); + } + auto operandShape = operandType.getShape(); SmallVector dimExprs; - AffineMap inputMap = AffineMap::get(b->getContext()); - { - dimExprs.reserve(nloops); + dimExprs.reserve(nloops); - if (broadcastOp.broadcast_dimensions()) { - for (const auto& broadcastDim : - enumerate(broadcastOp.broadcast_dimensions().getIntValues())) { - int size = broadcastDim.value().getSExtValue(); - // TODO(pifon): Add support for args with dynamic shapes for the case - // when a dimension of size 1 is broadcasted into dim of size N. - AffineExpr affineExpr = operandShape[broadcastDim.index()] == 1 - ? b->getAffineConstantExpr(0) - : b->getAffineDimExpr(size); - dimExprs.push_back(affineExpr); - } - } - if (dimExprs.empty()) { - // The input is a scalar, i.e. this is a scalar broadcast op. - inputMap = AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()); - } else { - inputMap = AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, - b->getContext()); + if (broadcastOp.broadcast_dimensions()) { + for (const auto& broadcastDim : + enumerate(broadcastOp.broadcast_dimensions().getIntValues())) { + int size = broadcastDim.value().getSExtValue(); + bool expansion_needed = operandShape[broadcastDim.index()] == 1 && + resultType.getShape()[size] != 1; + // TODO(pifon): Add support for args with dynamic shapes for the case + // when a dimension of size 1 is broadcasted into dim of size N. + dimExprs.push_back(expansion_needed ? 
b->getAffineConstantExpr(0) + : b->getAffineDimExpr(size)); } } return b->getAffineMapArrayAttr( - {inputMap, b->getMultiDimIdentityMap(nloops)}); + {AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()), + b->getMultiDimIdentityMap(nloops)}); } }; @@ -620,21 +619,25 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, // TODO(ataei): Remove this pattern, CopyOp is folded away. PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, @@ -717,18 +720,23 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, - PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 538b0cf492d..ea4ba8dab6b 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -128,6 +128,7 @@ tf_xla_py_test( name = "adagrad_da_test", size = "small", srcs = ["adagrad_da_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -165,6 +166,7 @@ tf_xla_py_test( srcs = ["add_n_test.py"], # TensorList ops are not implemented in the on-demand compilation model yet. 
disabled_backends = ["cpu_ondemand"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -225,6 +227,7 @@ tf_xla_py_test( name = "complex_div_test", size = "medium", srcs = ["complex_div_test.py"], + enable_mlir_bridge = True, enabled_backends = [ "cpu", "gpu", @@ -449,6 +452,7 @@ tf_xla_py_test( name = "clustering_test", size = "small", srcs = ["clustering_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -466,6 +470,7 @@ tf_xla_py_test( name = "concat_ops_test", size = "medium", srcs = ["concat_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "many_xla_args", @@ -488,6 +493,7 @@ tf_xla_py_test( name = "conv2d_test", size = "medium", srcs = ["conv2d_test.py"], + enable_mlir_bridge = True, python_version = "PY3", shard_count = 10, tags = [ @@ -510,6 +516,7 @@ tf_xla_py_test( name = "conv3d_test", size = "medium", srcs = ["conv3d_test.py"], + enable_mlir_bridge = True, python_version = "PY3", shard_count = 5, tags = [ @@ -555,6 +562,7 @@ tf_xla_py_test( name = "dynamic_slice_ops_test", size = "small", srcs = ["dynamic_slice_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -571,6 +579,7 @@ tf_xla_py_test( name = "einsum_op_test", size = "medium", srcs = ["einsum_op_test.py"], + enable_mlir_bridge = True, enabled_backends = [ "cpu", "gpu", @@ -592,6 +601,7 @@ tf_xla_py_test( name = "reshape_op_test", size = "small", srcs = ["reshape_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -663,6 +673,7 @@ tf_xla_py_test( name = "fifo_queue_test", size = "medium", srcs = ["fifo_queue_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -702,6 +713,7 @@ tf_xla_py_test( name = "slice_ops_test", size = "small", srcs = ["slice_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -737,6 +749,7 @@ tf_xla_py_test( name = "function_test", size = "small", srcs = ["function_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -881,6 +894,7 @@ tf_xla_py_test( name = "nary_ops_test", size = "small", srcs = ["nary_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -898,6 +912,7 @@ tf_xla_py_test( name = "nullary_ops_test", size = "small", srcs = ["nullary_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1220,6 +1235,7 @@ tf_xla_py_test( name = "stack_ops_test", size = "small", srcs = ["stack_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "config-cuda-only", @@ -1280,6 +1296,7 @@ tf_xla_py_test( srcs = ["tensor_array_ops_test.py"], # TensorArray ops are not implemented in the on-demand compilation model yet. 
disabled_backends = ["cpu_ondemand"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "config-cuda-only", @@ -1308,6 +1325,7 @@ tf_xla_py_test( srcs = ["tensor_list_ops_test.py"], # TensorList ops are not implemented in the on-demand compilation model yet. disabled_backends = ["cpu_ondemand"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1326,6 +1344,7 @@ tf_xla_py_test( name = "ternary_ops_test", size = "medium", srcs = ["ternary_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1368,6 +1387,7 @@ tf_xla_py_test( size = "medium", srcs = ["fused_batchnorm_test.py"], python_version = "PY3", + shard_count = 5, tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip ], @@ -1501,6 +1521,7 @@ tf_xla_py_test( name = "data_format_ops_test", size = "small", srcs = ["data_format_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1735,6 +1756,7 @@ tf_xla_py_test( name = "placeholder_test", size = "small", srcs = ["placeholder_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1791,6 +1813,7 @@ tf_xla_py_test( name = "conv_node_name_test", size = "medium", srcs = ["conv_node_name_test.py"], + enable_mlir_bridge = True, python_version = "PY3", shard_count = 5, tags = [ @@ -1837,6 +1860,7 @@ tf_xla_py_test( name = "special_math_test", size = "medium", srcs = ["special_math_test.py"], + enable_mlir_bridge = True, shard_count = 5, tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index d9721a3c8ac..00ed6d83e2e 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -73,8 +73,6 @@ class BinaryOpsTest(xla_test.XLATestCase): self.assertAllCloseAccordingToType( result[i], expected[i], rtol=rtol, atol=atol) - @test_util.disable_mlir_bridge( - "F16 type is not supported in CreateDenseElementsAttrFromLiteral") def testFloatOps(self): for dtype in self.float_types: if dtype == dtypes.bfloat16.as_numpy_dtype: @@ -299,7 +297,6 @@ class BinaryOpsTest(xla_test.XLATestCase): ] self._testBinary(bitwise_ops.right_shift, lhs, rhs, expected=expected) - @test_util.disable_mlir_bridge("TODO(b/153896312): Handle unsigned ints") def testAdd(self): for dtype in self.numeric_types: self._testBinary( @@ -326,7 +323,6 @@ class BinaryOpsTest(xla_test.XLATestCase): expected=np.array([3.0269620882574744, 3.3149631512242195], dtype=dtype)) - @test_util.disable_mlir_bridge("TODO(b/153896312): Handle unsigned ints") def testMultiply(self): for dtype in self.numeric_types: self._testBinary( @@ -390,7 +386,6 @@ class BinaryOpsTest(xla_test.XLATestCase): expected=np.array([[16], [81]], dtype=dtype), rtol=rtol) - @test_util.disable_mlir_bridge("TODO(b/153896312): Handle unsigned ints") def testNumericOps(self): for dtype in self.numeric_types: self._testBinary( @@ -934,7 +929,6 @@ class BinaryOpsTest(xla_test.XLATestCase): expected = np.array([op(l, r) for l, r in zip(lhs, rhs)], dtype=np.bool) self._testBinary(op, lhs, rhs, expected=expected) - 
@test_util.disable_mlir_bridge("TODO(b/153896312): Handle unsigned ints") def testBroadcasting(self): """Tests broadcasting behavior of an operator.""" @@ -1230,6 +1224,8 @@ class BinaryOpsTest(xla_test.XLATestCase): [7, 7, 7, 7, 7, 7]], dtype=dtype)) + @test_util.disable_mlir_bridge( + "Requires concatenate op support in MlirHloBuilder") def testSymmetricMirrorPad(self): mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "SYMMETRIC") for dtype in self.numeric_types: @@ -1261,6 +1257,8 @@ class BinaryOpsTest(xla_test.XLATestCase): np.array([[0, 0], [0, 0]], dtype=np.int32), expected=np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)) + @test_util.disable_mlir_bridge( + "Requires concatenate op support in MlirHloBuilder") def testReflectMirrorPad(self): mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "REFLECT") for dtype in self.numeric_types: @@ -1414,6 +1412,7 @@ class BinaryOpsTest(xla_test.XLATestCase): ], equality_test=self.ListsAreClose) + @test_util.disable_mlir_bridge("TODO(b/155097657): Debug incorrect answer") def testTile(self): for dtype in self.numeric_types: self._testBinary( @@ -1502,7 +1501,6 @@ class BinaryOpsTest(xla_test.XLATestCase): np.array([1, 0], dtype=np.int32), expected=np.array([[1 + 1j, 3 + 3j], [2 - 2j, 4 - 4j]], dtype=dtype)) - @test_util.disable_mlir_bridge("Enable tf.Cross Compilation") def testCross(self): for dtype in self.float_types: self._testBinary( @@ -1572,6 +1570,8 @@ class BinaryOpsTest(xla_test.XLATestCase): np.array([2, 1, 5], dtype=np.int32), expected=np.array([2, 3, 5], dtype=np.int32)) + @test_util.disable_mlir_bridge("Error handling") + def testBroadcastArgsError(self): with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, "Incompatible shapes"): self._testBinary(array_ops.broadcast_dynamic_shape, @@ -1579,6 +1579,8 @@ class BinaryOpsTest(xla_test.XLATestCase): np.array([4, 5, 6], dtype=np.int32), expected=None) + @test_util.disable_mlir_bridge( + "Requires BroadcastInDim method in MlirHloBuilder") def testBroadcastTo(self): for dtype in self.all_types: x = np.random.randint(0, high=100, size=[2, 3]) diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py index 10dd2d6542c..f35ded924d5 100644 --- a/tensorflow/compiler/tests/concat_ops_test.py +++ b/tensorflow/compiler/tests/concat_ops_test.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gradients_impl @@ -293,6 +294,7 @@ class ConcatTest(xla_test.XLATestCase): # The purpose of this is to ensure that XLA on GPU will not run out of memory # with too many arguments. 
+ @test_util.disable_mlir_bridge("TODO(b/153895138): Debug.") def testConcatLargeNumberOfTensors(self): if "CPU" in self.device: self.skipTest("This test can time out on CPU, so we will just allow " diff --git a/tensorflow/compiler/tests/gather_nd_op_test.py b/tensorflow/compiler/tests/gather_nd_op_test.py index 70377af6bdc..90ac515764b 100644 --- a/tensorflow/compiler/tests/gather_nd_op_test.py +++ b/tensorflow/compiler/tests/gather_nd_op_test.py @@ -38,7 +38,6 @@ class GatherNdTest(xla_test.XLATestCase): feed_dict = {paramsp: params, indicesp: indices} return gather_nd_t.eval(feed_dict=feed_dict) - @test_util.disable_mlir_bridge("TODO(b/153896312): Handle unsigned ints") def testSimpleDtype(self): for dtype in self.numeric_types: self.assertAllEqual( @@ -47,6 +46,7 @@ class GatherNdTest(xla_test.XLATestCase): np.array([8, 1, 2, 3, 7, 5], dtype=dtype), np.array([[4], [4], [0]], np.int32))) + @test_util.disable_mlir_bridge("Error handling") def testEmptyIndicesAndParamsOKButJustEmptyParamsFails(self): with self.session(): params = np.ones((3, 3), dtype=np.float32) diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index b89472b8085..81779203955 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -30,7 +30,6 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_image_ops from tensorflow.python.ops import image_ops @@ -979,7 +978,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): - @test_util.with_forward_compatibility_horizons(None, [2020, 4, 21]) def testBatchedNMSFrom6(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1017,7 +1015,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): indices_output) self.assertAllEqual([5, 4], num_valid_output) - @test_util.with_forward_compatibility_horizons(None, [2020, 4, 21]) def testBatchedNMSFrom6Max3(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1051,7 +1048,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([[0, 1, 2], [0, 1, 3]], indices_output) self.assertAllEqual([3, 3], num_valid_output) - @test_util.with_forward_compatibility_horizons(None, [2020, 4, 21]) def testBatchedNMSSingleFrom6Max3(self): boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]] @@ -1082,7 +1078,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([0, 1, 2], indices_output) self.assertAllEqual(3, num_valid_output) - @test_util.with_forward_compatibility_horizons(None, [2020, 4, 21]) def testBatchedNMSSingleFrom6NoPad(self): boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]] @@ -1112,7 +1107,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([0, 1, 2, 4, 5], indices_output) self.assertAllEqual(5, num_valid_output) - @test_util.with_forward_compatibility_horizons(None, [2020, 4, 21]) def 
testBatchedNMSBatchDimsFrom6Max3(self): boxes_data = [[[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1146,7 +1140,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([[[0, 1, 2], [0, 1, 3]]], indices_output) self.assertAllEqual([[3, 3]], num_valid_output) - @test_util.with_forward_compatibility_horizons(None, [2020, 4, 21]) def testBatchedNMSScoreThresholdFrom6Max3(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1182,7 +1175,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([3, 2], num_valid_output) self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output) - @test_util.with_forward_compatibility_horizons(None, [2020, 4, 21]) def testBatchedNMSUnsortedInputFrom6(self): boxes_data = [[[0, 2, 1, 2], [3, 3, 4, 4], [0, 0, 1, 1], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8]], @@ -1219,7 +1211,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): indices_output) self.assertAllEqual([5, 4], num_valid_output) - @test_util.with_forward_compatibility_horizons(None, [2020, 4, 21]) def testBatchedNMSNoncanonicalizedInputFrom6(self): boxes_data = [[[1, 0, 0, 1], [4, 3, 3, 4], [1, 0.4, 0, 1.4], [1, 0.6, 0, 1.6], [1, 0.8, 0, 1.8], [1, 2, 0, 2]], @@ -1257,7 +1248,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): indices_output) self.assertAllEqual([5, 4], num_valid_output) - @test_util.with_forward_compatibility_horizons(None, [2020, 4, 21]) def testBatchedNMSScoreThresholdCanInputsFrom6Max3(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1293,7 +1283,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([3, 2], num_valid_output) self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output) - @test_util.with_forward_compatibility_horizons(None, [2020, 4, 21]) def testBatchedNMSFrom6DynamicInput(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py index 465f368db82..a1bb64eb88d 100644 --- a/tensorflow/compiler/tests/ternary_ops_test.py +++ b/tensorflow/compiler/tests/ternary_ops_test.py @@ -24,6 +24,7 @@ import scipy.special as sps from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops @@ -47,6 +48,8 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase): {'start': 1, 'end': 2, 'num': 1}, {'start': 1, 'end': 4, 'num': 3}, {'start': 0, 'end': 41, 'num': 42}) + @test_util.disable_mlir_bridge( + 'TODO(b/156174708): Dynamic result types not supported') def testLinspace(self, start, end, num): expected = np.linspace(start, end, num, dtype=np.float32) result = self._testTernary( @@ -74,6 +77,7 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase): np.int32(2), expected=np.array([1, 3, 5], dtype=np.int32)) + @test_util.disable_mlir_bridge('TODO(b/155949336)') def testSelect(self): for dtype in self.numeric_types: self._testTernary( @@ -211,6 +215,7 @@ class TernaryOpsTest(xla_test.XLATestCase, 
parameterized.TestCase): upper, expected=np.minimum(np.maximum(x, lower), upper)) + @test_util.disable_mlir_bridge('Enable tf.Betainc Compilation') def testBetaincSanity(self): # This operation is only supported for float32 and float64. for dtype in self.numeric_types & {np.float32, np.float64}: @@ -230,7 +235,7 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase): { 'sigma': 1e15, 'rtol': 1e-6, - 'atol': 1e-6 + 'atol': 1e-4 }, { 'sigma': 30, @@ -240,7 +245,7 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase): { 'sigma': 1e-8, 'rtol': 5e-4, - 'atol': 3e-6 + 'atol': 3e-4 }, { 'sigma': 1e-16, @@ -248,6 +253,7 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase): 'atol': 2e-4 }, ) + @test_util.disable_mlir_bridge('Enable tf.Betainc Compilation') def testBetainc(self, sigma, rtol, atol): # This operation is only supported for float32 and float64. for dtype in self.numeric_types & {np.float32, np.float64}: diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index cd9ba983785..3e36f67615b 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -186,8 +186,6 @@ class UnaryOpsTest(xla_test.XLATestCase): self._assertOpOutputMatchesExpected( math_ops.cos, x, expected=np.cos(x), rtol=tol, atol=1e-5) - @test_util.disable_mlir_bridge( - "TODO(b/153812660): Handle tf.Softmax compilation") def testFloatOps(self): for dtype in self.float_types: x = np.arange(-0.90, 0.90, 0.25) @@ -514,6 +512,11 @@ class UnaryOpsTest(xla_test.XLATestCase): ], dtype=dtype)) + @test_util.disable_mlir_bridge( + "TODO(b/153812660): Handle tf.QuantizeAndDequantize compilation") + def testQuantizeAndDequantize(self): + for dtype in self.float_types: + def quantize_and_dequantize_v2(x): return array_ops.quantize_and_dequantize_v2( x, -127, 127, signed_input=True, num_bits=8) @@ -598,8 +601,7 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([-1, -0.5, 0, 0.3], dtype=dtype), expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype)) - @test_util.disable_mlir_bridge( - "Complex types not supported in CreateDenseElementsAttrFromLiteral") + @test_util.disable_mlir_bridge("TODO(b/156135423): Fix ConvertSigmoidOp") def testComplexOps(self): for dtype in self.complex_types: @@ -757,7 +759,6 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype), expected=np.array([1, -4, 2.7, 0], dtype=ctypes[dtype])) - @test_util.disable_mlir_bridge("TODO(b/153896312): Handle unsigned ints") def testIntOps(self): for dtype in self.int_types: self._assertOpOutputMatchesExpected( diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py index df388c655d0..f3e915daa67 100644 --- a/tensorflow/compiler/tests/xla_ops_test.py +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -51,7 +51,6 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): equality_fn = self.assertAllClose equality_fn(result, expected, rtol=1e-3) - @test_util.disable_mlir_bridge('Not supported yet') def testAdd(self): for dtype in self.numeric_types: self._assertOpOutputMatchesExpected( @@ -72,7 +71,6 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): np.array([7, 11], dtype=dtype)), expected=np.array([[8, 13], [10, 15]], dtype=dtype)) - @test_util.disable_mlir_bridge('Not supported yet') def testBroadcast(self): for dtype in self.numeric_types: v = np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 
2]) @@ -81,7 +79,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): args=(v,), expected=np.tile(v, (7, 42, 1, 1))) - @test_util.disable_mlir_bridge('Unsigned ints are not supported yet') + @test_util.disable_mlir_bridge('Dynamic result types not supported') def testShiftRightLogical(self): self._assertOpOutputMatchesExpected( xla.shift_right_logical, @@ -93,7 +91,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)), expected=np.array([0x0FFFFFFF, 1], dtype=np.uint32)) - @test_util.disable_mlir_bridge('Unsigned ints are not supported yet') + @test_util.disable_mlir_bridge('Dynamic result types not supported') def testShiftRightArithmetic(self): self._assertOpOutputMatchesExpected( xla.shift_right_arithmetic, @@ -110,7 +108,6 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): xla_data_pb2.PrecisionConfig.HIGHEST) @parameterized.parameters(*PRECISION_VALUES) - @test_util.disable_mlir_bridge('Not supported yet') def testConv(self, precision): for dtype in set(self.float_types).intersection( set([dtypes.bfloat16.as_numpy_dtype, np.float32])): @@ -195,7 +192,6 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): args=(np.array([1, 2, 3], dtype=dtype),), expected=np.array([-1, -2, -3], dtype=dtype)) - @test_util.disable_mlir_bridge('Not supported yet') def testPad(self): for dtype in self.numeric_types: @@ -320,6 +316,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): [[673, 674], [683, 684], [693, 694]]]), dtype=dtype)) + @test_util.disable_mlir_bridge('Error handling') def testDynamicSliceWithIncorrectStartIndicesShape(self): with self.session() as session: with self.test_scope(): @@ -333,6 +330,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): (r'start_indices must be a vector with length equal to input rank, ' r'but input rank is 3 and start_indices has shape \[2\].*')) + @test_util.disable_mlir_bridge('Error handling') def testDynamicSliceWithIncorrectSizeIndicesShape(self): with self.session() as session: with self.test_scope(): diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index 8ca30479330..356798c19bd 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -496,6 +496,7 @@ cc_library( "//tensorflow/core/grappler/costs:graph_properties", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf_headers", ], ) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index a90ac172c32..a43b16e9e6a 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -1456,12 +1456,13 @@ Status Converter::TransposeTensor(nvinfer1::ITensor* input_tensor, absl::string_view name, nvinfer1::ITensor** output_tensor) { const auto dims = input_tensor->getDimensions(); - - if (order_with_batch_dim.size() - 1 != size_t(dims.nbDims)) { + const int order_size = use_implicit_batch_ ? 
order_with_batch_dim.size() - 1 + : order_with_batch_dim.size(); + if (order_size != size_t(dims.nbDims)) { return errors::InvalidArgument( "Rank of perm for transpose does not match with that of the input."); } - if (order_with_batch_dim[0] != 0) { + if (use_implicit_batch_ && order_with_batch_dim[0] != 0) { return errors::Unimplemented( "Transpose at batch dimension is not supported."); } @@ -1472,8 +1473,13 @@ Status Converter::TransposeTensor(nvinfer1::ITensor* input_tensor, MarkQuantizationRangesAsInferrable(input_tensor, layer->getOutput(0)); nvinfer1::Permutation permutation; - for (int32_t i = 0; i < dims.nbDims; ++i) { - permutation.order[i] = order_with_batch_dim[i + 1] - 1; + if (use_implicit_batch_) { + for (int32_t i = 0; i < dims.nbDims; ++i) { + permutation.order[i] = order_with_batch_dim[i + 1] - 1; + } + } else { + std::copy(order_with_batch_dim.begin(), order_with_batch_dim.end(), + permutation.order); } VLOG(1) << "TransposeTensor permutation: " << DebugString(permutation, dims.nbDims); @@ -2271,11 +2277,13 @@ Status ConvertTranspose(OpConverterParams* params) { // Verify the permutation. nvinfer1::ITensor* input_tensor = inputs.at(0).tensor(); - if (perm.size() - 1 != size_t(input_tensor->getDimensions().nbDims)) { + const int perm_size = + params->use_implicit_batch ? perm.size() - 1 : perm.size(); + if (perm_size != size_t(input_tensor->getDimensions().nbDims)) { return errors::InvalidArgument( "Rank of perm for transpose does not match with that of the input."); } - if (perm[0] != 0) { + if (params->use_implicit_batch && perm[0] != 0) { return errors::Unimplemented( "Transpose at batch dimension is not supported."); } @@ -2405,26 +2413,19 @@ Status ConvertExpandDims(OpConverterParams* params) { } Status Converter::SqueezeTensor(nvinfer1::ITensor* input, - const std::vector& trt_axes, + std::vector* input_dims, nvinfer1::ITensor** output) { - const nvinfer1::Dims dims = input->getDimensions(); - std::vector input_dims(dims.d, dims.d + dims.nbDims); - // Mark axes to remove by setting them to 0. - for (int axis : trt_axes) { - input_dims[axis] = 0; - } - #if IS_TRT_VERSION_GE(6, 0, 0, 0) // If the remaining dimensions of a squeeze operation have dynamic sizes, we // need to use TRT ops to build the result shape for the squeeze operation. // This is because IShuffleLayer::setReshapeDimensions treats -1 as a special // value. - if (absl::c_any_of(input_dims, [](int i) { return i == -1; })) { + if (absl::c_any_of(*input_dims, [](int i) { return i == -1; })) { nvinfer1::ITensor* shape = network()->addShape(*input)->getOutput(0); std::vector concat_inputs; - for (int i = 0; i < input_dims.size(); i++) { + for (int i = 0; i < input_dims->size(); i++) { // If input dim wasn't set to 0 earlier, we include it in new shape. - if (input_dims[i] != 0) { + if (input_dims->at(i) != 0) { concat_inputs.push_back( network() ->addSlice(*shape, {1, {i}}, {1, {1}}, {1, {1}}) @@ -2444,11 +2445,12 @@ Status Converter::SqueezeTensor(nvinfer1::ITensor* input, } #endif // Remove all dims which are equal to 0. - input_dims.erase(std::remove(input_dims.begin(), input_dims.end(), 0), - input_dims.end()); + input_dims->erase(std::remove(input_dims->begin(), input_dims->end(), 0), + input_dims->end()); // Reshape tensor. 
nvinfer1::Dims new_dims; - TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims)); + VLOG(2) << "input_dims" << input_dims; + TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(*input_dims, &new_dims)); TF_RETURN_IF_ERROR(PrepareTensorForShape(TRT_TensorOrWeights(input), new_dims, /*validation_only=*/false, output)); return Status::OK(); @@ -2467,31 +2469,48 @@ Status ConvertSqueeze(OpConverterParams* params) { TFAttrs attrs(node_def); auto squeeze_dims = attrs.get>("squeeze_dims"); if (squeeze_dims.empty()) { - return errors::Unimplemented( - "Squeeze is only implemented for explicit dims, at ", node_def.name()); - } - std::vector trt_axes; - trt_axes.reserve(squeeze_dims.size()); - for (int tf_axis : squeeze_dims) { - // If the axis is valid, then convert it to TRT axis, otherwise abort - // conversion. - int trt_axis; - TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims, node_def.name(), - params->use_implicit_batch, &trt_axis)); - // Make sure target dimension is size 1 or unknown size (-1) - if (input_dims[trt_axis] != -1 && input_dims[trt_axis] != 1) { - return errors::InvalidArgument( - "Dimension ", tf_axis, " with size ", input_dims[trt_axis], - " cannot be squeezed because it must be size 1, at ", + if (params->use_implicit_batch || !HasStaticShape(dims)) { + return errors::Unimplemented( + "Squeeze is not implemented for empty squeeze_dims, at ", node_def.name()); + } else { + // explicit batch mode with static input shape we squeeze all singleton + // dimensions + for (int& dim : input_dims) { + if (dim == 1) { + // Mark it for removal by setting it to 0 + dim = 0; + } + } + } + } else { + std::vector trt_axes; + trt_axes.reserve(squeeze_dims.size()); + for (int tf_axis : squeeze_dims) { + // If the axis is valid, then convert it to TRT axis, otherwise abort + // conversion. + int trt_axis; + TF_RETURN_IF_ERROR(ConvertAxis(tf_axis, dims.nbDims, node_def.name(), + params->use_implicit_batch, &trt_axis)); + // Make sure target dimension is size 1 or unknown size (-1) + if (input_dims[trt_axis] != -1 && input_dims[trt_axis] != 1) { + return errors::InvalidArgument( + "Dimension ", tf_axis, " with size ", input_dims[trt_axis], + " cannot be squeezed because it must be size 1, at ", + node_def.name()); + } + trt_axes.push_back(trt_axis); + } + // Mark axes to remove by setting them to 0. + for (int axis : trt_axes) { + input_dims[axis] = 0; } - trt_axes.push_back(trt_axis); } if (params->validation_only) return Status::OK(); nvinfer1::ITensor* output_tensor = nullptr; TF_RETURN_IF_ERROR(params->converter->SqueezeTensor( - input_tensor.tensor(), trt_axes, &output_tensor)); + input_tensor.tensor(), &input_dims, &output_tensor)); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); } diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index 8608c8226ee..2092aecd657 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -529,11 +529,9 @@ class Converter { // Helper function to add a squeeze op to the network. // - // The trt_axes argument lists those axes that need to be squeezed. Each axis - // in the list is numbered according to TRT convention (see ConvertAxis for - // details). - Status SqueezeTensor(nvinfer1::ITensor* input, - const std::vector& trt_axes, + // The input_dims argument stores the TRT dimensions of the input tensor, + // where the dimensions to be squeezed are replaced by 0. 
+ Status SqueezeTensor(nvinfer1::ITensor* input, std::vector* input_dims, nvinfer1::ITensor** output); // Creates an IConstantLayer using 'weights' whose dimensions are specified by diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index 3e9c5db80d0..884ed7a5771 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" +#include +#include #include #include #include @@ -24,6 +26,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/strings/match.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" @@ -64,8 +67,45 @@ namespace convert { using absl::StrCat; using ::testing::ElementsAre; using ::testing::ElementsAreArray; +using ::testing::FloatNear; +using ::testing::Matcher; using ::testing::NanSensitiveFloatNear; +// TensorRT modes for testing. We define the following three modes: +// 1. Implicit batch mode: The tensors have static (known) input shape and the +// the batch dimension (first dim) is removed from the TRT tensor shape. In +// a loose notation: trt_shape = tf_shape[1:]. This is the standard mode of +// a TensorRT network definition before TensorRT 6. +// 2. Explicit batch mode: static (known) input shape, but the batch dimension +// is part of the trt tensor shape. (trt_shape = tf_shape) +// 3. Dynamic shape mode allows unknown input shapes, and requires explicit +// batch size definition (trt_shape = tf_shape). +// +// Note that the Converter only distinguishes between two modes: +// - use_implicit_batch == true, this corresponds to kImplicitBatch, +// - use_implicit_batch == false which includes both kExplicitBatch and +// kDynamicShape. +// +// For the converter, the distinction between explicit batch or dynamic shape +// mode follows from the input tensors of the network: dynamic shape input +// implies dynamic shape mode, while static shape input tensors imply explicit +// batch mode. We want to test all these modes, therefore we define the +// TrtTestMode with the following three options. +enum class TrtTestMode { + kImplicitBatch = 0, + kExplicitBatch = 1, + kDynamicShape = 2 +}; + +#if IS_TRT_VERSION_GE(6, 0, 0, 0) +constexpr std::array ValidTrtModes = { + TrtTestMode::kImplicitBatch, TrtTestMode::kExplicitBatch, + TrtTestMode::kDynamicShape}; +#else +constexpr std::array ValidTrtModes = { + TrtTestMode::kImplicitBatch}; +#endif + // TODO(laigd): put this into some test utils file. void ExpectStatus(Status status, error::Code code = error::OK, const char* substr = nullptr) { @@ -86,6 +126,17 @@ nvinfer1::Dims GetTestDims(const std::vector& d) { return dims; } +// Prints the vector to the output stream. 
+template +std::ostream& operator<<(std::ostream& os, const std::vector& v) { + if (!v.empty()) { + os << '['; + std::copy(v.begin(), v.end(), std::ostream_iterator(os, ", ")); + os << "\b\b]"; + } + return os; +} + nvinfer1::DataType TfDataTypeToTrt(DataType tf_dtype) { switch (tf_dtype) { case DT_FLOAT: @@ -167,6 +218,21 @@ void ExpectTrtDimsEqualsArray(const std::vector& lhs, << " actual: " << DebugString(rhs); } +Matcher> ArrayFloatNear(const std::vector& values, + float max_abs_error = 1e-5, + bool nan_sensitive = false) { + std::vector> matchers; + matchers.reserve(values.size()); + for (const float& v : values) { + if (nan_sensitive) { + matchers.emplace_back(NanSensitiveFloatNear(v, max_abs_error)); + } else { + matchers.emplace_back(FloatNear(v, max_abs_error)); + } + } + return ElementsAreArray(matchers); +} + template void ExpectArrayNear(const std::vector& lhs, absl::Span rhs) { ASSERT_EQ(lhs.size(), rhs.size()); @@ -1217,6 +1283,17 @@ TEST_F(ConvertGraphDefToEngineTest, IdentityGraph) { TF_EXPECT_OK(RunConvertGraphDefToEngine(&s)); } +// Returns a vector of shapes from a vector of input tensors. This can be used +// to create optimization profiles. +Status GetShapeFromDataVec(DataVec input_data, + std::vector* shape_vec) { + shape_vec->reserve(input_data.size()); + std::transform(input_data.begin(), input_data.end(), + std::back_inserter(*shape_vec), + [](InputOutputData x) { return x.tensor.shape(); }); + return Status::OK(); +} + template inline absl::Span GetSpanForData(const InputOutputData& data) { const auto& tensor_map = data.tensor.flat(); @@ -1239,16 +1316,18 @@ class OpConverterTest : public ::testing::Test { return converter_->GetTensorOrWeights(name, output); } - void Reset() { + void Reset(TrtPrecisionMode precision_mode_to_test = TrtPrecisionMode::FP32, + TrtTestMode trt_mode = TrtTestMode::kImplicitBatch) { // Destroy existing TRT objects in a proper order. converter_.reset(nullptr); engine_.reset(nullptr); // Re-create them in proper order. converter_ = - std::move(Converter::Create(precision_mode_to_test_, + std::move(Converter::Create(precision_mode_to_test, /*use_calibration=*/false, &logger_, - /*use_implicit_batch=*/true) + /*use_implicit_batch=*/trt_mode == + TrtTestMode::kImplicitBatch) .ValueOrDie()); // Reset other related artifacts. @@ -1294,9 +1373,7 @@ class OpConverterTest : public ::testing::Test { } } - // TODO(laigd): test fp16 and int8 support for more converters. void BuildAndRun(const DataVec& input_data, DataVec* output_data, - TrtPrecisionMode precision_mode = TrtPrecisionMode::FP32, const int batch_size = 1) { // Mark the output tensor as TRT engine output. std::vector output_info; @@ -1308,13 +1385,21 @@ class OpConverterTest : public ::testing::Test { // Build the TRT engine. 
ASSERT_EQ(nullptr, engine_.get()); + TrtShapeOptimizationProfile profiles; + if (!converter_->use_implicit_batch()) { + // Create a single optimization profile for explicit batch mode + std::vector input_shapes; + TF_ASSERT_OK(GetShapeFromDataVec(input_data, &input_shapes)); + profiles.AddShape(input_shapes); + profiles.InitProfiles(); + } TF_ASSERT_OK( converter_->BuildCudaEngine(&engine_, /*max_batch_size=*/batch_size, /*max_workspace_size_bytes=*/1 << 26, /*allocator=*/nullptr, /*calibrator=*/nullptr, - /*profiles=*/nullptr)); + /*profiles=*/&profiles)); CHECK_NOTNULL(engine_.get()); CheckDataTypeMatches(input_data); CheckDataTypeMatches(*output_data); @@ -1323,6 +1408,9 @@ class OpConverterTest : public ::testing::Test { std::vector buffers(num_bindings); ASSERT_EQ(engine_->getNbBindings(), num_bindings); + // Since we have only 1 optimization profile (which is enabled by default) + // it is fine to create execution context directly, instead of calling + // profiles.CreateExecutionContexts() TrtUniquePtrType execution_context( engine_->createExecutionContext()); @@ -1350,22 +1438,81 @@ class OpConverterTest : public ::testing::Test { return true; } - // Add ITensor for both validation and conversion. - void AddTestTensor( - const string& name, const std::vector& dims, int batch_size = 1, + bool HasStaticShape(std::vector dims) const { + return !absl::c_any_of(dims, [](int i) { return i < 0; }); + } + + // Adds ITensor for both validation and conversion, assuming explicit batch + // dimension is included in dims (ie for an NCHW tensor dims = {N, C, H, W}). + void AddTestTensorWithExplicitBatchDim( + const string& name, const std::vector& dims, nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT) { DataType tf_dtype = TrtDataTypeToTf(trt_dtype); ops::Placeholder::Attrs attrs; TF_EXPECT_OK(TensorShapeUtils::MakeShape(dims, &attrs.shape_)); - attrs.shape_.InsertDim(0, batch_size); + auto input = ops::Placeholder(scope_.WithOpName(name), tf_dtype, attrs); node_inputs_[name] = input.output; // Add a real ITensor for conversion conditionally. - const nvinfer1::Dims trt_dims = GetTestDims(dims); - if (HasStaticShape(trt_dims)) { + const nvinfer1::Dims trt_dims = + TensorShapeToTrtDims(attrs.shape_, converter_->use_implicit_batch()); + if (!converter_->use_implicit_batch() || HasStaticShape(trt_dims)) { + int batch_size = dims[0]; TF_EXPECT_OK( converter_->AddInputTensor(name, trt_dtype, trt_dims, batch_size)); + } + } + + // Adds ITensor for both validation and conversion. The tensor can have + // partial input shape. This function defines static or dynamic shape input + // tensor for the network based on the trt_mode attribute. This is done + // automatically, unless the user overrides it with an explicit + // partial_input_shape_dims argument. + // + // Parameters: + // - dims actual dimensions of the tensor that we will use during the test + // (including explicit batch dim). This is not used if partial_input_shape + // is defined. + // - partial_input_shape dimensions which can incude unknown shapes. This can + // be empty, in that case the partial_input_shape will be set automatically + // depending on the trt_mode argument. (This also includse explicit batch + // dim). + // + // On return skip_test is false if trt_mode is not compatible with the + // partial input shape. 
+ void AddTestTensor( + const string& name, const std::vector& dims, + nvinfer1::DataType trt_dtype, TrtTestMode trt_mode, + const std::vector* partial_input_shape_dims = nullptr) { + std::vector partial_shape; + if (partial_input_shape_dims && !partial_input_shape_dims->empty()) { + partial_shape = *partial_input_shape_dims; + } else { + if (trt_mode == TrtTestMode::kDynamicShape) { + // In dynamic shape mode we set the all dims unknown. + partial_shape = std::vector(dims.size(), -1); + } else { + // Use static (known) input shapes. + partial_shape = dims; + } + } + AddTestTensorWithExplicitBatchDim(name, partial_shape, trt_dtype); + } + + // Adds ITensor for both validation and conversion. The difference compared to + // AddTestTensorWithExplicitBatchDim is in the meaning of the dims parameter. + // To define a tensor with NCHW shape, here we set dims = {C,H,W} and + // batch_size = N. TODO(tfeher) remove this function once all test are updated + // to use the other version of AddTestTensor which has the trt_mode arg. + void AddTestTensor( + const string& name, const std::vector& dims, int batch_size = 1, + nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT) { + std::vector dims_with_batch(dims.size() + 1); + dims_with_batch[0] = batch_size; + std::copy(dims.begin(), dims.end(), dims_with_batch.begin() + 1); + AddTestTensorWithExplicitBatchDim(name, dims_with_batch, trt_dtype); + if (HasStaticShape(dims)) { ASSERT_EQ(batch_size, converter_->batch_size_); } } @@ -1405,9 +1552,9 @@ class OpConverterTest : public ::testing::Test { grappler::GraphProperties graph_properties(item); TF_EXPECT_OK(graph_properties.InferStatically(true)); - TrtNodeValidator validator(graph_properties, precision_mode_to_test_, + TrtNodeValidator validator(graph_properties, converter_->precision_mode(), /*use_calibration=*/false, - /*use_implicit_batch=*/true); + converter_->use_implicit_batch()); ExpectStatus(validator.IsTensorRTCandidate(node), expected_code, expected_msg_substr); } @@ -1446,6 +1593,33 @@ class OpConverterTest : public ::testing::Test { } } + // Helper method to run both validation and conversion, and check the output + // shape. + void RunValidationAndConversion(const NodeDef& node_def, const Status& status, + const char* output_name, + const std::vector& exp_out_dims) { + RunValidationAndConversion(node_def, status.code(), + status.error_message().c_str(), true); + if (status.ok()) { + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights(output_name, &output)); + ASSERT_TRUE(output.is_tensor()); + if (converter_->use_implicit_batch() && !exp_out_dims.empty()) { + // We only check output shape implicit batch mode. In dynamic shape + // mode we need to wait for the concrate input shapes to be defined + // (by setBindingDimensions before enqueue) before we can check + // whether the output dims are equal. + // + // TODO(tamas) enable this check in explicit_batch_mode + + // Removing batch dim + auto out_dims = + std::vector(exp_out_dims.begin() + 1, exp_out_dims.end()); + ExpectTrtDimsEqualsArray(out_dims, output.tensor()->getDimensions()); + } + } + } + // Expose quantization_ranges_ for tests std::unordered_map& quantization_ranges() { return converter_->quantization_ranges_; @@ -1456,10 +1630,6 @@ class OpConverterTest : public ::testing::Test { } std::unique_ptr converter_; - protected: - // TODO(laigd): parameterize the test and make the precision mode a parameter. 
- TrtPrecisionMode precision_mode_to_test_ = TrtPrecisionMode::FP32; - private: Logger logger_; TrtUniquePtrType engine_; @@ -1473,6 +1643,127 @@ class OpConverterTest : public ::testing::Test { std::unique_ptr allocator_; }; +// General test parameters to be used with ops that take a single input tensor. +struct TestParamBase { + // Concrete input dimensions for the test (including the batch dim) + std::vector input_dims; + + // Dimensions to define an input with PartialTensorShape. This can be used to + // define networks with dynamic input shape. It can be left empty, in that + // case AddTestTensor sets partial shapes that are appropriate to TrtTestMode. + std::vector partial_input_dims; + + // Concrete (static) output dimensions, including batch size as first dim + std::vector expected_output_dims; + + // Parameter vector, has converter specific meaning. + std::vector param; + + // Expected status of conversion (with concrete error message) + Status status; + + // Expected status of BuildAndRun + Status runtime_status; +}; + +std::ostream& operator<<(std::ostream& os, const TestParamBase& p) { + os << "input_dims" << p.input_dims; + if (!p.partial_input_dims.empty()) { + os << ", partial_input_dims" << p.partial_input_dims; + } + if (!p.expected_output_dims.empty()) { + os << ", exp_out_dims" << p.expected_output_dims; + } + if (!p.param.empty()) { + os << ", param" << p.param; + } + os << ", " << p.status; + return os; +} + +// Parameterized version of OpConverterTest. This class will be instantiated +// to test all the TrtTestModes but only in FP32 precision. This means that we +// will use the following combinations of test parameters: +// 1. TrtTestMode: implicit batch, explicit batch, dynamic shape modes +// 2. DataType of the input TF tensors: DT_FLOAT +// 3. TrtPrecisionMode argument for the Converter: FP32 +class ParameterizedOpConverterTest + : public OpConverterTest, + public ::testing::WithParamInterface< + std::tuple> {}; + +// Instantiate parameter combinations to test. For debugging purposes it might +// make sense to run over all possible combinations, but normally a subset of +// them would be sufficient: +// - All valid options to TrtTestMode (implicit, explicit, dynamic shape) +// - DataType: is the TF data type of the input tensors. This usually only +// influences the data type added by Converter::AddInputTensor. We test the +// valid combinations of input data types in AddAndGetInputs, therefore +// for most of the OpConverterTest its is sufficient to test for DT_FLOAT. +// - TrtPrecisionMode: valid options are FP32, FP16 and INT8. This influences +// how TRT handles the precision inside the TRT network, but should not matter +// for the TF -> TRT conversion. Therefore it should be sufficient to test +// for FP32. +INSTANTIATE_TEST_CASE_P( + OpConvTestInstantiation, ParameterizedOpConverterTest, + ::testing::Combine(::testing::ValuesIn(ValidTrtModes), + ::testing::Values(DT_FLOAT), + ::testing::Values(TrtPrecisionMode::FP32))); + +// Builds and runs the converted network. Checks output tensor shape. Tests +// output values using a matcher. +template +void BuildAndRunConvertedNetwork(const string& name, OpConverterTest* test, + const TestParamBase& p, + const std::vector& input_vec, + const Matcher>& matcher) { + if (!p.status.ok()) { + // conversion was not successful, we cannot run the network + return; + } + if (!p.runtime_status.ok()) { + // Runtime error is expected. This can happen if the operation is invalid + // for the actual input shape. 
Usually we catch these errors during + // conversion. If the network was defined with dynamic input shape than we + // have to postpone these steps until runtime. + // + // TODO(tfeher) Instead of early return, modify BuildAndRun to handle + // runtime errors. + return; + } + typedef typename EnumToDataType::Type T; + TensorShape shape; + TF_EXPECT_OK(TensorShapeUtils::MakeShape(p.input_dims, &shape)); + const DataVec input_data{ + {"input", test->AsTensor(CastTestVector(input_vec), shape)}}; + DataVec output_data{{name, test->ConstructTensor(6)}}; + test->BuildAndRun(input_data, &output_data); + // Check the shape of the actual output tensor + TF_EXPECT_OK(TensorShapeUtils::MakeShape(p.expected_output_dims, &shape)); + EXPECT_TRUE(output_data[0].tensor.shape() == shape) + << "Expected shape: " << shape.DebugString() << ", actual shape" + << output_data[0].tensor.shape().DebugString(); + // Cast the output to float and compare to expected output + auto out_span = GetSpanForData(output_data[0]); + std::vector casted_output(out_span.begin(), out_span.end()); + EXPECT_THAT(casted_output, matcher); +} + +void InstantiateBuildAndRun(DataType tf_dtype, const string& name, + OpConverterTest* test, const TestParamBase& p, + const std::vector& input_vec, + const Matcher>& matcher) { + if (tf_dtype == DT_FLOAT) { + BuildAndRunConvertedNetwork(name, test, p, input_vec, matcher); + } else if (tf_dtype == DT_HALF) { + BuildAndRunConvertedNetwork(name, test, p, input_vec, matcher); + } else if (tf_dtype == DT_INT32) { + BuildAndRunConvertedNetwork(name, test, p, input_vec, matcher); + } else { + FAIL() << "Test not supported for " << tf_dtype; + } +} + template void CopyTensorElements(const Tensor& tensor, protobuf::RepeatedField* out) { out->Clear(); @@ -1610,56 +1901,72 @@ TEST_F(OpConverterTest, ConvertConst) { TestConvertConst(this); } -TEST_F(OpConverterTest, ConvertTranspose) { +TEST_P(ParameterizedOpConverterTest, ConvertTranspose) { + const auto& spec = GetParam(); + const TrtTestMode trt_mode = std::get<0>(spec); + // Data type of TF input tensors + const DataType tf_dtype = std::get<1>(spec); + // Precision mode used for TensorRT engine + TrtPrecisionMode converter_precision = std::get<2>(spec); + // Get the NodeDef for Transpose. Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto input = ops::Placeholder(s.WithOpName("input"), tf_dtype); auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32); auto transpose = ops::Transpose(s.WithOpName("my_transpose"), input, weights); const NodeDef& node_def = transpose.operation.node()->def(); - { - // Permutation is a tensor, should fail. - Reset(); - AddTestTensor("input", {1, 2, 3}); - AddTestTensor("weights", {3}); - RunValidationAndConversion( - node_def, error::UNIMPLEMENTED, - "The input \"perm\" for Transpose must be a constant, at my_transpose"); + std::vector test_params = { + // For the first test we leave param empty. This signals to use a + // input as weight which will be invalid + TestParamBase{{1, 1, 2, 3}, + {}, + {}, + {}, + Status(error::UNIMPLEMENTED, + "The input \"perm\" for Transpose must be a " + "constant, at my_transpose")}, + TestParamBase{{1, 1, 2, 3}, + {}, + {}, + {0, 1, 2}, + Status(error::INVALID_ARGUMENT, + "Rank of perm for transpose does not match with " + "that of the input.")}, + // Transpose batch dim + TestParamBase{ + {1, 1, 2, 3}, + {}, + {3, 2, 1, 1}, + {3, 2, 1, 0}, + (trt_mode == TrtTestMode::kImplicitBatch) + ? 
Status(error::UNIMPLEMENTED, + "Transpose at batch dimension is not supported") + : Status::OK()}, + TestParamBase{{1, 1, 2, 3}, {}, {1, 3, 1, 2}, {0, 3, 1, 2}}, + }; + if (trt_mode == TrtTestMode::kDynamicShape) { + // Dynamic shape tests where some shapes are known + test_params.push_back(TestParamBase{ + {1, 1, 2, 3}, {-1, 1, 2, -1}, {1, 3, 1, 2}, {0, 3, 1, 2}}); } - { - // Transpose at batch dimension, should fail. - Reset(); - AddTestTensor("input", {1, 2, 3}); - AddTestWeights("weights", {4}, {1, 0, 2, 3}); - RunValidationAndConversion(node_def, error::UNIMPLEMENTED, - "Transpose at batch dimension is not supported"); - } - { - // Permutation rank doesn't match, should fail. - Reset(); - AddTestTensor("input", {1, 2, 3}); - AddTestWeights("weights", {3}, {0, 1, 2}); - RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "Rank of perm for transpose does not match with that of the input."); - } - { - // Ok. - Reset(); - AddTestTensor("input", {1, 2, 3}); - AddTestWeights("weights", {4}, {0, 3, 1, 2}); - RunValidationAndConversion(node_def); - TRT_TensorOrWeights output; - TF_EXPECT_OK(GetTensorOrWeights("my_transpose", &output)); - ASSERT_TRUE(output.is_tensor()); - ExpectTrtDimsEqualsArray({3, 1, 2}, output.tensor()->getDimensions()); - - const DataVec input_data{{"input", AsTensor({1, 2, 3, 4, 5, 6})}}; - DataVec output_data{{"my_transpose", ConstructTensor(6)}}; - BuildAndRun(input_data, &output_data); - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(1, 4, 2, 5, 3, 6)); + std::vector expected_values{1, 4, 2, 5, 3, 6}; + for (auto p : test_params) { + SCOPED_TRACE(p); + Reset(converter_precision, trt_mode); + AddTestTensor("input", p.input_dims, TfDataTypeToTrt(tf_dtype), trt_mode, + &p.partial_input_dims); + if (p.param.empty()) { + AddTestTensor("weights", {3}); + } else { + AddTestWeights("weights", {static_cast(p.param.size())}, + p.param); + } + RunValidationAndConversion(node_def, p.status, "my_transpose", + p.expected_output_dims); + InstantiateBuildAndRun(tf_dtype, "my_transpose", this, p, + {1, 2, 3, 4, 5, 6}, + ElementsAreArray(expected_values)); } } @@ -1756,7 +2063,7 @@ TEST_F(OpConverterTest, ConvertReshape) { const DataVec input_data{{"input", AsTensor(input_vec)}}; DataVec output_data{ {"my_reshape", ConstructTensor(input_vec.size())}}; - BuildAndRun(input_data, &output_data, TrtPrecisionMode::FP32, batch_size); + BuildAndRun(input_data, &output_data, batch_size); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(input_vec)); } @@ -1908,28 +2215,24 @@ TEST_F(OpConverterTest, ConvertMatMul) { } { // Make sure that INT8 mode uses IFullyConnectedLayer when possible. - precision_mode_to_test_ = TrtPrecisionMode::INT8; - Reset(); + Reset(TrtPrecisionMode::INT8); NodeDef node_def = get_matmul_nodedef(DT_FLOAT, false, false); AddTestTensor("input", {2, 1, 1}); AddTestWeights("weights", {2, 2}, {0, 1, 2, 3}); RunValidationAndConversion(node_def); CheckAddedLayers(this, false); CheckAddedLayers(this, true); - precision_mode_to_test_ = TrtPrecisionMode::FP32; } { // Make sure that INT8 mode doesn't try to use IFullyConnectedLayer when not // compatible. In this case we can't use FC because weights is a tensor. 
- precision_mode_to_test_ = TrtPrecisionMode::INT8; - Reset(); + Reset(TrtPrecisionMode::INT8); NodeDef node_def = get_matmul_nodedef(DT_FLOAT, false, false); AddTestTensor("input", {2, 1, 1}); AddTestTensor("weights", {2, 2}); RunValidationAndConversion(node_def); CheckAddedLayers(this, true); CheckAddedLayers(this, false); - precision_mode_to_test_ = TrtPrecisionMode::FP32; } TestMatMulHelper(this, get_matmul_nodedef, "MatMul"); } @@ -1961,15 +2264,13 @@ TEST_F(OpConverterTest, ConvertBatchMatMul) { { // Make sure that INT8 mode doesn't try to use IFullyConnectedLayer when not // compatible. In this case we can't use FC because transpose_a is true. - precision_mode_to_test_ = TrtPrecisionMode::INT8; - Reset(); + Reset(TrtPrecisionMode::INT8); NodeDef node_def = get_batch_matmul_nodedef(DT_FLOAT, true, false); AddTestTensor("input", {1, 2, 2}); AddTestWeights("weights", {2, 2}, {0, 1, 2, 3}); RunValidationAndConversion(node_def); CheckAddedLayers(this, true); CheckAddedLayers(this, false); - precision_mode_to_test_ = TrtPrecisionMode::FP32; } for (bool transpose_a : {false, true}) { @@ -2144,10 +2445,7 @@ void TestBinaryOp(OpConverterTest* test, bool operand_1_is_tensor, ExpectTrtDimsEqualsArray({2, 2}, output.tensor()->getDimensions()); // After broadcasting first input becomes {3, 6, 3, 6} and second input // becomes {2, 3, 2, 3}. - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32, - /*batch_size=*/2); + test->BuildAndRun(input_data, &output_data, /*batch_size=*/2); if (node_def.op() == "Add") { EXPECT_THAT( GetSpanForData(output_data[0]), @@ -2281,10 +2579,7 @@ void TestAddN(OpConverterTest* test) { ExpectTrtDimsEqualsArray({1, 2}, output.tensor()->getDimensions()); DataVec output_data{{"my_addn", test->ConstructTensor(4)}}; - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32, - /*batch_size=*/2); + test->BuildAndRun(input_data, &output_data, /*batch_size=*/2); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(CastTestVector({3, 6, 9, 12}))); } @@ -2308,9 +2603,7 @@ void TestAddN(OpConverterTest* test) { ExpectTrtDimsEqualsArray({1, 2}, output.tensor()->getDimensions()); DataVec output_data{{"my_addn", test->ConstructTensor(2)}}; - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + test->BuildAndRun(input_data, &output_data); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(CastTestVector({5, 8}))); } @@ -2332,10 +2625,9 @@ TEST_F(OpConverterTest, ConvertAddN) { } TEST_F(OpConverterTest, ConvertQuantize) { - precision_mode_to_test_ = TrtPrecisionMode::INT8; { // FakeQuantWithMinMaxArgs attributes are empty, should fail. - Reset(); + Reset(TrtPrecisionMode::INT8); NodeDef node_def = MakeNodeDef("my_quantize", "FakeQuantWithMinMaxArgs", {"input"}); AddTestTensor("input", {1, 2, 3}); @@ -2346,7 +2638,7 @@ TEST_F(OpConverterTest, ConvertQuantize) { } { // FakeQuantWithMinMaxArgs ranges set via attributes, ok. - Reset(); + Reset(TrtPrecisionMode::INT8); Scope s = Scope::NewRootScope(); auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); auto quantize_attrs = ops::FakeQuantWithMinMaxArgs::Min(-6.0f).Max(6.0f); @@ -2364,7 +2656,7 @@ TEST_F(OpConverterTest, ConvertQuantize) { } { // FakeQuantWithMinMaxVars ranges set via inputs, ok. 
- Reset(); + Reset(TrtPrecisionMode::INT8); Scope s = Scope::NewRootScope(); auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT); @@ -2385,7 +2677,7 @@ TEST_F(OpConverterTest, ConvertQuantize) { } { // QuantizeAndDequantizeV2 ranges set via inputs, ok. - Reset(); + Reset(TrtPrecisionMode::INT8); Scope s = Scope::NewRootScope(); auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT); @@ -2406,7 +2698,7 @@ TEST_F(OpConverterTest, ConvertQuantize) { } { // QuantizeAndDequantizeV2 Range inputs are tensors, should fail. - Reset(); + Reset(TrtPrecisionMode::INT8); Scope s = Scope::NewRootScope(); auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT); @@ -2424,7 +2716,7 @@ TEST_F(OpConverterTest, ConvertQuantize) { } { // QuantizeAndDequantizeV3 ranges set via inputs, ok. - Reset(); + Reset(TrtPrecisionMode::INT8); Scope s = Scope::NewRootScope(); auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT); @@ -2477,9 +2769,7 @@ void TestConvertSquare(OpConverterTest* test) { // Engine outputs are converted to FP16 automatically if we set FP16 mode in // the builder. DataVec output_data{{"my_square", test->ConstructTensor(num_inputs)}}; - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + test->BuildAndRun(input_data, &output_data); ExpectArrayNear(expected_outputs, GetSpanForData(output_data[0])); } @@ -2828,124 +3118,117 @@ TEST_F(OpConverterTest, ConvertExpandDims) { } } -TEST_F(OpConverterTest, ConvertSqueeze) { - { - // No attrs, should fail. - Reset(); - Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); - auto squeeze = ops::Squeeze(s.WithOpName("my_squeeze"), input); - const NodeDef& node_def = squeeze.operation.node()->def(); - AddTestTensor("input", {1, 2, 3}); - RunValidationAndConversion( - node_def, error::UNIMPLEMENTED, - "Squeeze is only implemented for explicit dims, at my_squeeze"); - } +TEST_P(ParameterizedOpConverterTest, ConvertSqueeze) { + const auto& spec = GetParam(); + const TrtTestMode trt_mode = std::get<0>(spec); + const bool use_implicit_batch = (trt_mode == TrtTestMode::kImplicitBatch); + // Data type of TF input tensors + const DataType tf_dtype = std::get<1>(spec); + // Precision mode used for TensorRT engine + TrtPrecisionMode converter_precision = std::get<2>(spec); // Get the NodeDef for Squeeze. 
- auto get_squeeze_nodedef = [](std::vector axis) -> NodeDef { + auto get_squeeze_nodedef = [tf_dtype](std::vector axes) -> NodeDef { Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); - ops::Squeeze::Attrs squeeze_attrs; - squeeze_attrs.axis_ = gtl::ArraySlice(axis); // non-absl ok - auto squeeze = - ops::Squeeze(s.WithOpName("my_squeeze"), input, squeeze_attrs); - return squeeze.operation.node()->def(); + auto input = ops::Placeholder(s.WithOpName("input"), tf_dtype); + if (!axes.empty()) { + ops::Squeeze::Attrs squeeze_attrs; + squeeze_attrs.axis_ = gtl::ArraySlice(axes); // non-absl ok + auto squeeze = + ops::Squeeze(s.WithOpName("my_squeeze"), input, squeeze_attrs); + return squeeze.operation.node()->def(); + } else { + auto squeeze = ops::Squeeze(s.WithOpName("my_squeeze"), input); + return squeeze.operation.node()->def(); + } }; - - { - // Input is weights, should fail. - Reset(); - NodeDef node_def = get_squeeze_nodedef({0}); - AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); - RunValidationAndConversion( - node_def, error::UNIMPLEMENTED, - "The input \"input\" for Squeeze must be a tensor, at my_squeeze"); - } - { - // Squeeze batch dim, should fail. - Reset(); - NodeDef node_def = get_squeeze_nodedef({0}); - AddTestTensor("input", {1, 2, 3}); - RunValidationAndConversion(node_def, error::UNIMPLEMENTED, - "TensorRT does not allow manipulation of the " - "batch dimension, at my_squeeze"); - } - { - // Squeeze batch dim via negative axis, should fail. - Reset(); - NodeDef node_def = get_squeeze_nodedef({-4}); - AddTestTensor("input", {1, 2, 3}); - RunValidationAndConversion(node_def, error::UNIMPLEMENTED, - "TensorRT does not allow manipulation of the " - "batch dimension, at my_squeeze"); - } - { - // Squeeze >= rank(input), should fail. - Reset(); - NodeDef node_def = get_squeeze_nodedef({4}); - AddTestTensor("input", {1, 2, 3}); - RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "Axis value of 4 is out of bounds, must be in range [-4, 4), at " - "my_squeeze"); - } - { - // Squeeze < -rank(input), should fail. - Reset(); - NodeDef node_def = get_squeeze_nodedef({-5}); - AddTestTensor("input", {1, 2, 3}); - RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "Axis value of -5 is out of bounds, must be in range [-4, 4), at " - "my_squeeze"); - } - { - // Squeeze an axis with size != 1, should fail. - Reset(); - NodeDef node_def = get_squeeze_nodedef({2}); - AddTestTensor("input", {1, 2, 3}); - RunValidationAndConversion( - node_def, error::INVALID_ARGUMENT, - "Dimension 2 with size 2 cannot be squeezed because it must be size 1, " - "at my_squeeze"); - } - - struct TestParams { - std::vector input_dims; - std::vector axis; - std::vector expected_output_dims; + std::vector test_params = { + TestParamBase{ + {1, 2, 1, 3}, // input dims + {}, // input partial dims + {2, 3}, // expected output dims + {}, // axis + trt_mode == TrtTestMode::kExplicitBatch + ? Status::OK() + : Status{error::UNIMPLEMENTED, + "Squeeze is not implemented for empty squeeze_dims, at " + "my_squeeze"}}, + TestParamBase{{1, 2, 1, 3}, + {}, + {2, 1, 3}, + {0}, + use_implicit_batch + ? Status{error::UNIMPLEMENTED, + "TensorRT does not allow manipulation of the " + "batch dimension, at my_squeeze"} + : Status::OK()}, + TestParamBase{{1, 2, 1, 3}, + {}, + {2, 1, 3}, + {-4}, + use_implicit_batch + ? 
Status{error::UNIMPLEMENTED, + "TensorRT does not allow manipulation of the " + "batch dimension, at my_squeeze"} + : Status::OK()}, + TestParamBase{ + {1, 1, 2, 3}, + {}, + {}, + {4}, + Status{error::INVALID_ARGUMENT, + "Axis value of 4 is out of bounds, must be in range [-4, 4), " + "at my_squeeze"}}, + TestParamBase{ + {1, 1, 2, 3}, + {}, + {}, + {-5}, + Status{error::INVALID_ARGUMENT, + "Axis value of -5 is out of bounds, must be in range [-4, 4), " + "at my_squeeze"}}, + TestParamBase{{1, 1, 2, 3}, {}, {1, 2, 3}, {1}}, + TestParamBase{{1, 1, 2, 3}, {}, {1, 2, 3}, {-3}}, + TestParamBase{{1, 2, 3, 1}, {}, {1, 2, 3}, {3}}, + TestParamBase{{1, 2, 3, 1}, {}, {1, 2, 3}, {-1}}, + TestParamBase{{1, 1, 2, 1, 3, 1}, {}, {1, 2, 3}, {1, 3, 5}}, + TestParamBase{{1, 1, 2, 1, 3, 1}, {}, {1, 2, 3}, {3, 1, 5}}, + TestParamBase{{1, 1, 2, 1, 3, 1}, {}, {1, 2, 3}, {-1, -3, -5}}, + TestParamBase{{1, 1, 2, 1, 3, 1}, {}, {1, 2, 3}, {1, -3, 5}}, + TestParamBase{{1, 1, 6}, {}, {1, 6}, {1}}, + TestParamBase{{1, 6, 1}, {}, {1, 6}, {2}}, }; + auto squeeze_non_singleton = TestParamBase{ + {1, 1, 2, 3}, + {}, + {}, + {2}, + Status{error::INVALID_ARGUMENT, + "Dimension 2 with size 2 cannot be squeezed because it must be " + "size 1, at my_squeeze"}}; - // Ok. - std::vector ok_params = { - TestParams{{1, 2, 3}, {1}, {2, 3}}, - TestParams{{1, 2, 3}, {-3}, {2, 3}}, - TestParams{{2, 3, 1}, {3}, {2, 3}}, - TestParams{{2, 3, 1}, {-1}, {2, 3}}, - TestParams{{1, 2, 1, 3, 1}, {1, 3, 5}, {2, 3}}, - TestParams{{1, 2, 1, 3, 1}, {3, 1, 5}, {2, 3}}, - TestParams{{1, 2, 1, 3, 1}, {-1, -3, -5}, {2, 3}}, - TestParams{{1, 2, 1, 3, 1}, {1, -3, 5}, {2, 3}}, - TestParams{{1, 6}, {1}, {6}}, - TestParams{{6, 1}, {2}, {6}}, - }; - for (int i = 0; i < ok_params.size(); ++i) { - Reset(); - NodeDef node_def = get_squeeze_nodedef(ok_params[i].axis); - AddTestTensor("input", ok_params[i].input_dims); - RunValidationAndConversion(node_def); - TRT_TensorOrWeights output; - TF_EXPECT_OK(GetTensorOrWeights("my_squeeze", &output)); - ASSERT_TRUE(output.is_tensor()); - ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, - output.tensor()->getDimensions()); + if (trt_mode == TrtTestMode::kDynamicShape) { + // In this test we try to squeeze axis=2 which has size > 1. In dynamic + // shape mode the converter sees only -1, so it cannot catch this error. 
+ squeeze_non_singleton.status = Status::OK(); // conversion status + squeeze_non_singleton.runtime_status = + errors::InvalidArgument("Negative number of dimensions -1"); + // Dynamic shape tests with partially known input shape + test_params.push_back(TestParamBase{{2, 1, 3}, {2, -1, 3}, {2, 3}, {1}}); + test_params.push_back(TestParamBase{{2, 1, 3}, {2, 1, -1}, {2, 3}, {1}}); + } + test_params.push_back(squeeze_non_singleton); - const DataVec input_data{{"input", AsTensor({1, 2, 3, 4, 5, 6})}}; - DataVec output_data{{"my_squeeze", ConstructTensor(6)}}; - BuildAndRun(input_data, &output_data); - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAre(1, 2, 3, 4, 5, 6)); + for (TestParamBase p : test_params) { + SCOPED_TRACE(p); + Reset(converter_precision, trt_mode); + NodeDef node_def = get_squeeze_nodedef(p.param); + AddTestTensor("input", p.input_dims, TfDataTypeToTrt(tf_dtype), trt_mode, + &p.partial_input_dims); + RunValidationAndConversion(node_def, p.status, "my_squeeze", + p.expected_output_dims); + InstantiateBuildAndRun(tf_dtype, "my_squeeze", this, p, {1, 2, 3, 4, 5, 6}, + ElementsAreArray({1, 2, 3, 4, 5, 6})); } } @@ -4776,10 +5059,8 @@ void TestConvertGather(OpConverterTest* test) { } DataVec output_data{ {"my_gather", test->ConstructTensor(expected_output.size())}}; - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32, - /*batch_size=*/expected_output_shape[0]); + test->BuildAndRun(input_data, &output_data, + /*batch_size=*/expected_output_shape[0]); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(converted_expected_output)); } @@ -4850,135 +5131,54 @@ TEST_F(OpConverterTest, ConvertGather) { TestConvertGather(this); } -TEST_F(OpConverterTest, ConvertUnary) { +template +NodeDef CreateUnaryOp() { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + return T(s.WithOpName("my_unary"), input).operation.node()->def(); +} + +TEST_P(ParameterizedOpConverterTest, ConvertUnary) { + const auto& spec = GetParam(); + const TrtTestMode trt_mode = std::get<0>(spec); + const DataType tf_dtype = std::get<1>(spec); + TrtPrecisionMode converter_precision = std::get<2>(spec); { // Input is weights, should fail. - Reset(); - Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); - auto neg = ops::Neg(s.WithOpName("my_unary"), input); - const NodeDef& node_def = neg.operation.node()->def(); + Reset(converter_precision, trt_mode); + const NodeDef node_def = CreateUnaryOp(); AddTestWeights("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, "The input \"x\" for Neg must be a tensor, at my_unary"); } - - // Get nodedef for unary layer. 
- auto get_unary_nodedef = [](string op_name) -> NodeDef { - Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); - if (op_name == "Abs") { - auto unary = ops::Abs(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Acos") { - auto unary = ops::Acos(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Acosh") { - auto unary = ops::Acosh(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Asin") { - auto unary = ops::Asin(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Asinh") { - auto unary = ops::Asinh(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Atan") { - auto unary = ops::Atan(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Atanh") { - auto unary = ops::Atanh(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Ceil") { - auto unary = ops::Ceil(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Cos") { - auto unary = ops::Cos(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Cosh") { - auto unary = ops::Cosh(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Exp") { - auto unary = ops::Exp(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Floor") { - auto unary = ops::Floor(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Log") { - auto unary = ops::Log(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Neg") { - auto unary = ops::Neg(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Reciprocal") { - auto unary = ops::Reciprocal(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Rsqrt") { - auto unary = ops::Rsqrt(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Sin") { - auto unary = ops::Sin(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Sinh") { - auto unary = ops::Sinh(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Sqrt") { - auto unary = ops::Sqrt(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } else if (op_name == "Tan") { - auto unary = ops::Tan(s.WithOpName("my_unary"), input); - return unary.operation.node()->def(); - } - EXPECT_TRUE(false); - return NodeDef(); - }; - // Get expected output for unary layer. 
- auto get_unary_output = [](string op_name, float input) -> float { - if (op_name == "Abs") { - return std::abs(input); - } else if (op_name == "Acos") { - return std::acos(input); - } else if (op_name == "Acosh") { - return std::acosh(input); - } else if (op_name == "Asin") { - return std::asin(input); - } else if (op_name == "Asinh") { - return std::asinh(input); - } else if (op_name == "Atan") { - return std::atan(input); - } else if (op_name == "Atanh") { - return std::atanh(input); - } else if (op_name == "Ceil") { - return std::ceil(input); - } else if (op_name == "Cos") { - return std::cos(input); - } else if (op_name == "Cosh") { - return std::cosh(input); - } else if (op_name == "Exp") { - return std::exp(input); - } else if (op_name == "Floor") { - return std::floor(input); - } else if (op_name == "Log") { - return std::log(input); - } else if (op_name == "Neg") { - return -input; - } else if (op_name == "Reciprocal") { - return 1.0 / input; - } else if (op_name == "Rsqrt") { - return 1.0 / std::sqrt(input); - } else if (op_name == "Sin") { - return std::sin(input); - } else if (op_name == "Sinh") { - return std::sinh(input); - } else if (op_name == "Sqrt") { - return std::sqrt(input); - } else if (op_name == "Tan") { - return std::tan(input); - } - EXPECT_TRUE(false); - return 0; - }; - + using OpFunc = std::function; + using ValFunc = float (*)(float); + std::map> op_map; +#define ADD_OP(name, op, compute) \ + op_map[name] = \ + std::make_pair(CreateUnaryOp, static_cast(compute)) + ADD_OP("Abs", ops::Abs, std::abs); + ADD_OP("Acos", ops::Acos, std::acos); + ADD_OP("Acosh", ops::Acosh, std::acosh); + ADD_OP("Asin", ops::Asin, std::asin); + ADD_OP("Asinh", ops::Asinh, std::asinh); + ADD_OP("Atan", ops::Atan, std::atan); + ADD_OP("Atanh", ops::Atanh, std::atanh); + ADD_OP("Ceil", ops::Ceil, std::ceil); + ADD_OP("Cos", ops::Cos, std::cos); + ADD_OP("Cosh", ops::Cosh, std::cosh); + ADD_OP("Exp", ops::Exp, std::exp); + ADD_OP("Floor", ops::Floor, std::floor); + ADD_OP("Log", ops::Log, std::log); + ADD_OP("Neg", ops::Neg, [](float x) { return -x; }); + ADD_OP("Reciprocal", ops::Reciprocal, [](float x) { return 1.0f / x; }); + ADD_OP("Rsqrt", ops::Rsqrt, [](float x) { return 1.0f / std::sqrt(x); }); + ADD_OP("Sin", ops::Sin, std::sin); + ADD_OP("Sinh", ops::Sinh, std::sinh); + ADD_OP("Sqrt", ops::Sqrt, std::sqrt); + ADD_OP("Tan", ops::Tan, std::tan); +#undef ADD_OP // Get list of ops to test. std::vector ops_to_test; // Add all ops supported by ConvertUnary. @@ -4989,26 +5189,30 @@ TEST_F(OpConverterTest, ConvertUnary) { } // Add other unary ops to test. ops_to_test.push_back("Rsqrt"); - // Ok. 
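The ADD_OP table above collapses the two removed if/else chains into a single map from op name to a node builder plus a reference function. Below is a self-contained sketch of the same table-driven idea using only the standard library; the macro and map layout are simplified stand-ins for illustration, not the patch's actual helpers.

#include <cmath>
#include <cstdio>
#include <functional>
#include <map>
#include <string>
#include <vector>

int main() {
  using ValFunc = std::function<float(float)>;
  std::map<std::string, ValFunc> op_map;
// Registering an op is a one-liner; adding a new op no longer touches any
// if/else chain.
#define ADD_OP(name, compute) op_map[name] = ValFunc(compute)
  ADD_OP("Abs", [](float x) { return std::abs(x); });
  ADD_OP("Exp", [](float x) { return std::exp(x); });
  ADD_OP("Neg", [](float x) { return -x; });
  ADD_OP("Rsqrt", [](float x) { return 1.0f / std::sqrt(x); });
#undef ADD_OP

  const std::vector<float> inputs = {-0.9f, 0.6f, 2.9f};
  for (const auto& entry : op_map) {
    std::printf("%s:", entry.first.c_str());
    for (float x : inputs) std::printf(" %g", entry.second(x));
    std::printf("\n");
  }
  return 0;
}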
+ // Prepare test parameters + auto p = TestParamBase{ + {1, 1, 2, 3}, // input dims + {}, // input partial dims + {1, 1, 2, 3}, // expected output dims + }; for (const string& op_name : ops_to_test) { - Reset(); - NodeDef node_def = get_unary_nodedef(op_name); - AddTestTensor("input", {1, 2, 3}); - RunValidationAndConversion(node_def); - TRT_TensorOrWeights output; - TF_EXPECT_OK(GetTensorOrWeights("my_unary", &output)); - ASSERT_TRUE(output.is_tensor()); - ExpectTrtDimsEqualsArray({1, 2, 3}, output.tensor()->getDimensions()); - - const std::vector input = {-0.9f, 0.6f, 0.0f, -3.5f, 100.0f, 2.9f}; - const DataVec input_data{{"input", AsTensor(input)}}; - DataVec output_data{{"my_unary", ConstructTensor(6)}}; - BuildAndRun(input_data, &output_data); - for (int i = 0; i < input.size(); ++i) { - const float expected_output = get_unary_output(op_name, input[i]); - EXPECT_THAT(GetSpanForData(output_data[0])[i], - NanSensitiveFloatNear(expected_output, 0.0001)); + SCOPED_TRACE(op_name); + Reset(converter_precision, trt_mode); + if (!op_map.count(op_name)) { + FAIL() << "Unary op test map does not contain op " << op_name; } + NodeDef node_def = op_map[op_name].first(); + + AddTestTensor("input", p.input_dims, TfDataTypeToTrt(tf_dtype), trt_mode); + RunValidationAndConversion(node_def, Status::OK(), "my_unary", + p.expected_output_dims); + + std::vector input_values{-0.9f, 0.6f, 0.0f, -3.5f, 100.0f, 2.9f}; + std::vector output; + std::transform(input_values.begin(), input_values.end(), + std::back_inserter(output), op_map[op_name].second); + InstantiateBuildAndRun(tf_dtype, "my_unary", this, p, input_values, + ArrayFloatNear(output, 0.0001, true)); } } @@ -5112,9 +5316,7 @@ void TestConvertConcat(OpConverterTest* test) { DataVec output_data{ {"my_concat", test->ConstructTensor(ok_params[i].expected_output.size())}}; - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + test->BuildAndRun(input_data, &output_data); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(ok_params[i].expected_output)); } @@ -5279,9 +5481,7 @@ void TestConvertSplit(OpConverterTest* test) { // Verify output values are correct. const DataVec input_data{ {"value", test->AsTensor(ok_params[i].value)}}; - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + test->BuildAndRun(input_data, &output_data); for (int j = 0; j < outputs.size(); ++j) { EXPECT_THAT(GetSpanForData(output_data[j]), ElementsAreArray(ok_params[i].expected_outputs[j])); @@ -5458,9 +5658,7 @@ void TestConvertUnpack(OpConverterTest* test) { // Verify output values are correct. const DataVec input_data{ {"value", test->AsTensor(ok_params[i].value)}}; - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + test->BuildAndRun(input_data, &output_data); for (int j = 0; j < outputs.size(); ++j) { EXPECT_THAT(GetSpanForData(output_data[j]), ElementsAreArray(ok_params[i].expected_outputs[j])); @@ -5629,9 +5827,7 @@ void TestConvertPack(OpConverterTest* test) { } DataVec output_data{{"my_pack", test->ConstructTensor( params[i].expected_output.size())}}; - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? 
TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + test->BuildAndRun(input_data, &output_data); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(params[i].expected_output)); } @@ -5779,9 +5975,7 @@ void TestConvertArgMinMax(OpConverterTest* test) { DataVec output_data{ {"my_arg", test->ConstructTensor( params[i].expected_argmax_output.size())}}; - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + test->BuildAndRun(input_data, &output_data); if (node_def.op() == "ArgMax") { EXPECT_THAT(GetSpanForData(output_data[0]), @@ -5880,9 +6074,7 @@ void TestConvertDepthSpaceShuffle( DataVec input_data{{"input", test->AsTensor(params[i].input_value)}}; DataVec output_data{{"my_shuffle", test->ConstructTensor( params[i].expected_output.size())}}; - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + test->BuildAndRun(input_data, &output_data); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(params[i].expected_output)); } @@ -6158,9 +6350,7 @@ void TestConvertClipByValue(OpConverterTest* test) { DataVec input_data{{"t", test->AsTensor(params[i].input_value)}}; DataVec output_data{{"my_clip", test->ConstructTensor( params[i].expected_output.size())}}; - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + test->BuildAndRun(input_data, &output_data); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(params[i].expected_output)); } @@ -6268,9 +6458,7 @@ void TestConvertSquaredDifference(OpConverterTest* test) { DataVec output_data{ {"my_squared_diff", test->ConstructTensor(params[i].expected_output.size())}}; - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + test->BuildAndRun(input_data, &output_data); EXPECT_THAT(GetSpanForData(output_data[0]), ElementsAreArray(params[i].expected_output)); } @@ -6375,9 +6563,7 @@ void TestConvertResize(OpConverterTest* test) { {"my_resize", test->ConstructTensor( params[i].expected_nearest_output_values.size())}}; - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + test->BuildAndRun(input_data, &output_data); if (node_def.op() == "ResizeBilinear") { ExpectArrayAlmostEqual(params[i].expected_bilinear_output_values, @@ -6477,9 +6663,7 @@ void TestConvertPad(OpConverterTest* test) { {"my_pad", test->ConstructTensor( params[i].expected_output_values.size())}}; - test->BuildAndRun( - input_data, &output_data, - dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + test->BuildAndRun(input_data, &output_data); ExpectArrayAlmostEqual(params[i].expected_output_values, GetSpanForData(output_data[0]), CType(1e-5)); } diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment.cc b/tensorflow/compiler/tf2tensorrt/segment/segment.cc index 5b97a8f1aa2..749335f1b09 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.cc @@ -371,6 +371,174 @@ string TensorPropertiesToString( }); } +// From the given list of input properties, returns the leading shape, which is +// the shape that determines the batch size of the operation. The leading shape +// is selected from the group of input shapes with the highest rank as follows: +// . 
If all of those shapes have non-negative values for the batch dimension, +// the leading shape is the one with the largest value for the batch +// dimension. +// . If some or all of those shapes have negative values for the batch +// dimension, and the rest of those shapes have 1 for the batch dimension, +// the leading shape is the first of those shapes with a negative value for +// the batch dimension. +// . Otherwise, we can't determine the leading shape for the operation and +// have to exclude the operation from TRT. +// +// Examples: +// case-1: a[1,3,4] + b[2,3,4] => leading shape [2,3,4] +// case-2: a[2,3,4] + b[scalar] => leading shape [2,3,4] +// case-3: a[-1,3,4] + b[1,3,4] => leading shape [-1,3,4] +// case-4: a[-1,3,4] + b[2,3,4] => no leading shape +// +// We have to return "no leading shape" for case-4 to exclude such operation +// from being translated for this reason: +// The actually input for "a" have to be in the shape of [2,3,4] for the +// operation to be valid. On the other hand, if we translate the operation +// to implicit batch mode, it will becomes a[3,4]+b[3,4] which is valid for +// any input shape of "a". +// +// This routine assumes the input program is valid. For example, we shouldn't +// see invalid operation like a[2,3,4] + b[3,3,4]. It also assumes the input +// properties is not empty and all input have known shapes. +// +// TODO(bixia): find a way to share this knowledge with the converter. +// TODO(bixia): investigate the use of symbolic shape analysis to improve +// segmentation, such as by requiring the dynamic dimensions to have the same +// negative value. +absl::optional FindLeadingShape( + absl::Span properties) { + DCHECK(!properties.empty()); + const TensorShapeProto* result; + int max_batch_dim_value; + auto choose_shape_with_higher_rank = [&](const TensorShapeProto* s) { + result = s; + max_batch_dim_value = s->dim_size() < 1 ? 1 : s->dim(0).size(); + }; + + DCHECK(!properties[0].shape().unknown_rank()); + choose_shape_with_higher_rank(&properties[0].shape()); + + for (const OpInfo::TensorProperties& p : properties.subspan(1)) { + DCHECK(!p.shape().unknown_rank()); + if (p.shape().dim_size() < result->dim_size()) continue; + + if (p.shape().dim_size() > result->dim_size()) { + choose_shape_with_higher_rank(&p.shape()); + continue; + } + + // Among the shapes with the same rank, choose the one with a dynamic batch + // size. If no shapes have a dynamic batch size, choose the one with the + // largest size. + if (result->dim_size() < 1) continue; + + if (p.shape().dim(0).size() < 0 || result->dim(0).size() < 0) { + if (p.shape().dim(0).size() < 0 && result->dim(0).size() >= 0) { + result = &p.shape(); + } else { + max_batch_dim_value = + std::max(max_batch_dim_value, p.shape().dim(0).size()); + } + + continue; + } + + if (p.shape().dim(0).size() > result->dim(0).size()) { + result = &p.shape(); + max_batch_dim_value = result->dim(0).size(); + } + } + + if (result->dim_size() > 0 && result->dim(0).size() < 0) { + // dynamic batch size + if (max_batch_dim_value <= 1) { + return result; + } else { + return absl::nullopt; + } + } + + return result; +} + +// Returns the inputs that are relevant to determinate the batch size of the +// operation. This routine handles the following cases: +// . Operations that support implicit boradcasting, such as operation mul. +// In this case, we need to inspect all the inputs in order to determine the +// batch size of the operation. +// . Special cases. Such as "Conv2DBackpropInput", "Conv3DBackpropInputV2". 
+// . The batch size of an operation is determined by the first input of the
+//   operation.
+absl::Span<const OpInfo::TensorProperties> GetInputsToDeterminateBatchSize(
+    const Node* node, const std::vector<OpInfo::TensorProperties>& all_inputs) {
+  // TODO(bixia): Find a way to share this knowledge with the converter.
+  static std::set<string> broadcast_supporting_ops = {
+      // ops corresponding to ConvertBinary in the converter
+      "Add",
+      "AddV2",
+      "Mul",
+      "Sub",
+      "Div",
+      "FloorDiv",
+      "RealDiv",
+      "Minimum",
+      "Maximum",
+      "Pow",
+      // other ops that need GetTrtBroadcastShape to convert
+      "BiasAdd",
+      "SquaredDifference",
+      "BatchMatMul",
+      "BatchMatMulV2",
+  };
+  const string& op = node->def().op();
+
+  if (op == "Conv2DBackpropInput" || op == "Conv3DBackpropInputV2") {
+    DCHECK_EQ(all_inputs.size(), 3);
+    return absl::MakeSpan(all_inputs).subspan(2, 1);
+  }
+
+  if (broadcast_supporting_ops.count(op)) {
+    return absl::MakeSpan(all_inputs);
+  }
+
+  // This is the common case for the operations that don't support implicit
+  // broadcasting: the first operand determines the batch size. All other
+  // cases are handled before reaching here.
+  return absl::MakeSpan(all_inputs).subspan(0, 1);
+}
+
+// Returns true if we can remove the implicit batch dimension of the operation
+// and translate it for TRT.
+//
+// In particular, if the input shape has dynamic rank or the input shape rank
+// is less than 2, we can't remove the implicit batch dimension and generate
+// a new operation for TRT translation.
+bool OperationCanBeTranslatedToImplicitBatch(
+    const grappler::GraphProperties* graph_properties, const Node* node) {
+  VLOG(3) << "process node " << node->name();
+  if (node->num_inputs() == 0) return true;
+  if (!graph_properties || !graph_properties->HasInputProperties(node->name()))
+    return false;
+
+  VLOG(3) << "input shapes "
+          << TensorPropertiesToString(
+                 graph_properties->GetInputProperties(node->name()));
+
+  const std::vector<OpInfo::TensorProperties>& all_input_properties =
+      graph_properties->GetInputProperties(node->name());
+  absl::Span<const OpInfo::TensorProperties> input_properties =
+      GetInputsToDeterminateBatchSize(node, all_input_properties);
+  if (absl::c_any_of(input_properties, [](const OpInfo::TensorProperties& p) {
+        return p.shape().unknown_rank();
+      })) {
+    return false;
+  }
+
+  absl::optional<const TensorShapeProto*> leading_shape =
+      FindLeadingShape(input_properties);
+  return leading_shape.has_value() && leading_shape.value()->dim_size() >= 2;
+}
+
 // Returns true if we can't be sure that the operand with the given properties
 // won't have negative values for non-batch dimensions.
 //
@@ -467,6 +635,42 @@ void ContractEdge(SimpleEdge* edge, SimpleGraph* graph,
   }
 }
+// Returns a batch size representation for a segment that only contains the
+// given node.
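As a quick cross-check of the selection rule documented for FindLeadingShape, here is a standalone sketch that applies the same rule to plain integer vectors, with -1 standing for a dynamic dimension, and reproduces the four documented cases. It is a simplified model assuming C++17, not the function's actual implementation.

#include <algorithm>
#include <cstdio>
#include <optional>
#include <vector>

using Shape = std::vector<long long>;  // -1 marks a dynamic dimension

// Simplified model of the leading-shape rule: consider only the shapes with
// the highest rank, then reconcile their batch (first) dimensions. Assumes a
// non-empty list of shapes.
std::optional<Shape> FindLeadingShapeModel(const std::vector<Shape>& shapes) {
  size_t max_rank = 0;
  for (const Shape& s : shapes) max_rank = std::max(max_rank, s.size());
  std::vector<Shape> top;
  for (const Shape& s : shapes)
    if (s.size() == max_rank) top.push_back(s);
  if (max_rank == 0) return top.front();  // only scalars

  bool has_dynamic = false;
  long long max_static = 0;
  const Shape* first_dynamic = nullptr;
  for (const Shape& s : top) {
    if (s[0] < 0) {
      has_dynamic = true;
      if (first_dynamic == nullptr) first_dynamic = &s;
    } else {
      max_static = std::max(max_static, s[0]);
    }
  }
  if (!has_dynamic) {
    // All batch dimensions are known: the largest one leads.
    for (const Shape& s : top)
      if (s[0] == max_static) return s;
  }
  // A dynamic batch dimension leads only if every static batch dimension is 1.
  if (max_static <= 1) return *first_dynamic;
  return std::nullopt;  // dynamic mixed with static > 1: give up
}

int main() {
  auto show = [](const char* label, const std::optional<Shape>& s) {
    std::printf("%s:", label);
    if (!s) { std::printf(" no leading shape\n"); return; }
    for (long long d : *s) std::printf(" %lld", d);
    std::printf("\n");
  };
  show("case-1", FindLeadingShapeModel({{1, 3, 4}, {2, 3, 4}}));   // 2 3 4
  show("case-2", FindLeadingShapeModel({{2, 3, 4}, {}}));          // 2 3 4
  show("case-3", FindLeadingShapeModel({{-1, 3, 4}, {1, 3, 4}}));  // -1 3 4
  show("case-4", FindLeadingShapeModel({{-1, 3, 4}, {2, 3, 4}}));  // none
  return 0;
}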
+ClusterBatchSize GetClusterBatchSizeForNode( + const grappler::GraphProperties* graph_properties, const Node* node, + bool use_implicit_batch) { + ClusterBatchSize cluster_batch_size; + if (!use_implicit_batch || !node || node->num_inputs() == 0) { + return cluster_batch_size; + } + + if (!graph_properties || + !graph_properties->HasInputProperties(node->name())) { + VLOG(3) << "doesn't have input property"; + return cluster_batch_size.SetBatchSizeValue(-1); + } + + const std::vector& input_properties = + graph_properties->GetInputProperties(node->name()); + absl::optional optional_leading_shape = + FindLeadingShape(GetInputsToDeterminateBatchSize(node, input_properties)); + DCHECK(optional_leading_shape.has_value()); + const TensorShapeProto* leading_shape = optional_leading_shape.value(); + + DCHECK(!leading_shape->unknown_rank() && leading_shape->dim_size() >= 2); + return cluster_batch_size.SetBatchSizeValue(leading_shape->dim(0).size()); +} + +void AddSegmentForNode(const grappler::GraphProperties* graph_properties, + std::vector>* segments, + SimpleNode* node, bool use_implicit_batch) { + segments->emplace_back( + node, GetClusterBatchSizeForNode( + graph_properties, node == nullptr ? nullptr : node->tf_node(), + use_implicit_batch)); +} + } // namespace Status SegmentGraph(const Graph* tf_graph, @@ -528,6 +732,12 @@ Status SegmentGraph(const Graph* tf_graph, }; if (options.exclude_node_list.count(node->name()) != 0) { exclude_node("excluded by segmenter option"); + } else if (options.use_implicit_batch && + !OperationCanBeTranslatedToImplicitBatch(graph_properties, + node->tf_node())) { + exclude_node( + "implicit batch mode requires input shape with at least two " + "dimensions"); } else if (!options.allow_dynamic_non_batch_dim && OperationHasDynamicNonBatchDimension(graph_properties, node->tf_node())) { @@ -548,7 +758,8 @@ Status SegmentGraph(const Graph* tf_graph, << "(Op name: " << node->name(); } } - node_segments.emplace_back(node); + AddSegmentForNode(graph_properties, &node_segments, node, + options.use_implicit_batch); } string msg = StrCat( "There are ", num_unsupported_ops, " ops of ", unsupported_ops.size(), @@ -581,18 +792,23 @@ Status SegmentGraph(const Graph* tf_graph, return true; }); for (const SimpleNode* node : order) { - // All output nodes of 'node' have been visited... + // All output nodes of 'node' have been visited. VLOG(3) << "Trying node " << node->name() << " id=" << node->id(); - // 'node' must be a TRT candidate... + // 'node' must be a TRT candidate. if (node_segments[node->id()].Value() == nullptr) { VLOG(3) << "... not a TRT candidate"; continue; } - // Contract output edges to combine 'node' with output - // nodes. Iterate since combining two nodes may unblock other - // combining. + // Contract output edges to combine 'node' with output nodes. Repeat this + // step until no output edges can be further contracted. This is because + // contracting an output edge may unblock new edges for contracting. + ClusterBatchSize expected_batch_size = + node_segments[node->id()].BatchSize(); + VLOG(3) << "batch size " << expected_batch_size; while (true) { std::set contract_edges; + // TODO(bixia): consider merging the loop to find the edges and the loop + // to contract the edges. for (const SimpleEdge* out_edge : node->out_edges()) { VLOG(3) << "... out node " << out_edge->dst()->name() << " ( " << out_edge->dst()->id() << " <- " << node->id() << " )"; @@ -600,14 +816,26 @@ Status SegmentGraph(const Graph* tf_graph, VLOG(3) << "... ... 
Control Edge, Skipping"; continue; } - // Out node must be TRT candidate... + // Out node must be a TRT candidate. if (node_segments[out_edge->dst()->id()].Value() == nullptr) { VLOG(3) << "... ... not a TRT candidate"; continue; } + // Out node must have compatible batch size. + ClusterBatchSize out_batch_size = + node_segments[out_edge->dst()->id()].BatchSize(); + ClusterBatchSize merged_batch_size = expected_batch_size; + if (!merged_batch_size.MergeIfCompatible(out_batch_size)) { + VLOG(3) << "... ... incompatible batch size " + << expected_batch_size.ToString() << " " + << out_batch_size.ToString(); + continue; + } if (CanContractEdge(out_edge, graph)) { - VLOG(3) << "... ... can contract"; + VLOG(3) << "... ... can contract. new batch size " + << merged_batch_size.ToString(); contract_edges.insert(out_edge); + expected_batch_size = merged_batch_size; } else { VLOG(3) << "... ... cannot contract, would form cycle"; } @@ -624,7 +852,8 @@ Status SegmentGraph(const Graph* tf_graph, VLOG(3) << "Merge " << src->name() << " <- " << dst->name() << " (" << src->id() << " <- " << dst->id(); - node_segments[src->id()].Merge(&node_segments[dst->id()]); + TF_RETURN_IF_ERROR( + node_segments[src->id()].Merge(&node_segments[dst->id()])); // Contracting the edge leaves disconnected graph edges. // Remove these from the graph and from 'contract_edges' so we @@ -638,6 +867,12 @@ Status SegmentGraph(const Graph* tf_graph, graph->RemoveEdge(r); } } + ClusterBatchSize actual_batch_size = + node_segments[node->id()].BatchSize(); + if (expected_batch_size != actual_batch_size) { + return errors::Internal( + "expected batch size is not the same as the actual batch size"); + } } } diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc b/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc index 68195addb03..2437481a9c4 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc @@ -369,6 +369,154 @@ TEST_F(SegmentTest, ExcludeReshapeWithDynamicNonBatchDimensionInOutput) { RunTest(&g, &static_graph_properties, all_nodes, all_nodes, all_nodes, {}); } +TEST_F(SegmentTest, RankOneCannotUseImplicitBatch) { + Scope s = Scope::NewRootScope(); + auto input_0_shape = ops::Placeholder::Shape(TensorShape({3})); + auto input_1_shape = ops::Placeholder::Shape(TensorShape({3})); + auto input_0 = + ops::Placeholder(s.WithOpName("input-0"), DT_FLOAT, input_0_shape); + auto input_1 = + ops::Placeholder(s.WithOpName("input-1"), DT_FLOAT, input_1_shape); + auto const_val = ops::Const(s.WithOpName("const-scalar"), 1.0f, {}); + auto output_0 = ops::Add(s.WithOpName("output-0"), input_0, const_val); + auto output_1 = ops::Add(s.WithOpName("output-1"), input_1, const_val); + + grappler::GrapplerItem item; + item.fetch.push_back("output-0"); + item.fetch.push_back("output-1"); + TF_EXPECT_OK(s.ToGraphDef(&item.graph)); + + grappler::GraphProperties static_graph_properties(item); + TF_EXPECT_OK(static_graph_properties.InferStatically(true)); + + Graph g(OpRegistry::Global()); + TF_CHECK_OK( + ConvertGraphDefToGraph(GraphConstructorOptions(), item.graph, &g)); + + const std::set all_nodes = {"const-scalar", "output-0", "output-1"}; + EnableImplicitBatchModeForStaticEngine(); + RunTest(&g, &static_graph_properties, all_nodes, all_nodes, all_nodes, {}); +} + +TEST_F(SegmentTest, TwoChainsDiffBatchSizes) { + Scope s = Scope::NewRootScope(); + auto input_0_shape = ops::Placeholder::Shape(TensorShape({2, 3})); + auto input_1_shape = 
ops::Placeholder::Shape(TensorShape({5, 3})); + auto input_0 = + ops::Placeholder(s.WithOpName("input-0"), DT_FLOAT, input_0_shape); + auto input_1 = + ops::Placeholder(s.WithOpName("input-1"), DT_FLOAT, input_1_shape); + auto const_val = ops::Const(s.WithOpName("const-scalar"), 1.0f, {}); + auto output_0 = ops::Add(s.WithOpName("output-0"), input_0, const_val); + auto output_1 = ops::Add(s.WithOpName("output-1"), input_1, const_val); + + grappler::GrapplerItem item; + item.fetch.push_back("output-0"); + item.fetch.push_back("output-1"); + TF_EXPECT_OK(s.ToGraphDef(&item.graph)); + + grappler::GraphProperties static_graph_properties(item); + TF_EXPECT_OK(static_graph_properties.InferStatically(true)); + + Graph g(OpRegistry::Global()); + TF_CHECK_OK( + ConvertGraphDefToGraph(GraphConstructorOptions(), item.graph, &g)); + + const std::set all_nodes = {"const-scalar", "output-0", "output-1"}; + EnableImplicitBatchModeForStaticEngine(); + RunTest(&g, &static_graph_properties, all_nodes, all_nodes, all_nodes, + {{"output-0", "const-scalar"}}); +} + +TEST_F(SegmentTest, SameRankImplicitBroadcastingStaticBatchSize) { + Scope s = Scope::NewRootScope(); + auto input_0_shape = ops::Placeholder::Shape(TensorShape({2, 3, 1})); + auto input_1_shape = ops::Placeholder::Shape(TensorShape({1, 3, 4})); + auto input_2_shape = ops::Placeholder::Shape(TensorShape({2, 3, 4})); + auto input_0 = + ops::Placeholder(s.WithOpName("input-0"), DT_FLOAT, input_0_shape); + auto input_1 = + ops::Placeholder(s.WithOpName("input-1"), DT_FLOAT, input_1_shape); + auto input_2 = + ops::Placeholder(s.WithOpName("input-2"), DT_FLOAT, input_2_shape); + auto multiple = ops::Mul(s.WithOpName("multiple"), input_2, input_2); + auto output_0 = ops::Add(s.WithOpName("output-0"), input_0, multiple); + auto output_1 = ops::Add(s.WithOpName("output-1"), input_1, multiple); + + grappler::GrapplerItem item; + item.fetch.push_back("output-0"); + item.fetch.push_back("output-1"); + TF_EXPECT_OK(s.ToGraphDef(&item.graph)); + + grappler::GraphProperties static_graph_properties(item); + TF_EXPECT_OK(static_graph_properties.InferStatically(true)); + + Graph g(OpRegistry::Global()); + TF_CHECK_OK( + ConvertGraphDefToGraph(GraphConstructorOptions(), item.graph, &g)); + + const std::set all_nodes = {"multiple", "output-0", "output-1"}; + EnableImplicitBatchModeForStaticEngine(); + RunTest(&g, &static_graph_properties, all_nodes, all_nodes, all_nodes, + {all_nodes}); +} + +TEST_F(SegmentTest, SameRankImplicitBroadcastingDynamicBatchSize) { + Scope s = Scope::NewRootScope(); + auto input_0_shape = ops::Placeholder::Shape(PartialTensorShape({-1, 2})); + auto input_1_shape = ops::Placeholder::Shape(TensorShape({1, 2})); + auto input_0 = + ops::Placeholder(s.WithOpName("input-0"), DT_FLOAT, input_0_shape); + auto input_1 = + ops::Placeholder(s.WithOpName("input-1"), DT_FLOAT, input_1_shape); + auto const_val = ops::Const(s.WithOpName("const-val"), 1.0f, {1, 1}); + auto add_0 = ops::Add(s.WithOpName("add-0"), input_0, const_val); + auto output_0 = ops::Add(s.WithOpName("output-0"), input_0, add_0); + + grappler::GrapplerItem item; + item.fetch.push_back("output-0"); + TF_EXPECT_OK(s.ToGraphDef(&item.graph)); + + grappler::GraphProperties static_graph_properties(item); + TF_EXPECT_OK(static_graph_properties.InferStatically(true)); + + Graph g(OpRegistry::Global()); + TF_CHECK_OK( + ConvertGraphDefToGraph(GraphConstructorOptions(), item.graph, &g)); + + const std::set all_nodes = {"const-val", "add-0", "output-0"}; + 
EnableImplicitBatchModeForStaticEngine(); + RunTest(&g, &static_graph_properties, all_nodes, all_nodes, all_nodes, + {{"const-val", "add-0", "output-0"}}); +} + +TEST_F(SegmentTest, IncompatibleBatchSizes) { + Scope s = Scope::NewRootScope(); + auto input_0_shape = ops::Placeholder::Shape(PartialTensorShape({-1, 2})); + auto input_1_shape = ops::Placeholder::Shape(TensorShape({2, 2})); + auto input_0 = + ops::Placeholder(s.WithOpName("input-0"), DT_FLOAT, input_0_shape); + auto input_1 = + ops::Placeholder(s.WithOpName("input-1"), DT_FLOAT, input_1_shape); + auto const_val = ops::Const(s.WithOpName("const-val"), 1.0f, {2, 2}); + auto add_0 = ops::Add(s.WithOpName("add-0"), input_0, const_val); + auto output_0 = ops::Add(s.WithOpName("output-0"), input_0, add_0); + + grappler::GrapplerItem item; + item.fetch.push_back("output-0"); + TF_EXPECT_OK(s.ToGraphDef(&item.graph)); + + grappler::GraphProperties static_graph_properties(item); + TF_EXPECT_OK(static_graph_properties.InferStatically(true)); + + Graph g(OpRegistry::Global()); + TF_CHECK_OK( + ConvertGraphDefToGraph(GraphConstructorOptions(), item.graph, &g)); + + const std::set all_nodes = {"const-val", "add-0", "output-0"}; + EnableImplicitBatchModeForStaticEngine(); + RunTest(&g, &static_graph_properties, all_nodes, all_nodes, all_nodes, {}); +} } // namespace test } // namespace segment } // namespace tensorrt diff --git a/tensorflow/compiler/tf2tensorrt/segment/union_find.h b/tensorflow/compiler/tf2tensorrt/segment/union_find.h index 6458ae692fd..70e83c12fca 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/union_find.h +++ b/tensorflow/compiler/tf2tensorrt/segment/union_find.h @@ -16,51 +16,192 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_ #define TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_ +#include "absl/strings/str_format.h" +#include "absl/types/optional.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + namespace tensorflow { namespace tensorrt { namespace segment { -// Union-Find data structure. -// Each cluster has an associated value; when merging clusters we can control -// which value becomes the representative of the merged clusters. Values must be -// copyable. +// ClusterBatchSize is a data structure to record the batch size we have seen +// for a cluster during segmentation. +// +// When constructing clusters for implicit batch mode, we support the +// with both dynamic batch size and static batch size. We restrict nodes inside +// a cluster to either have dynamic batch size or have the same value for static +// batch size. For this reason, we use a field has_dynamic_batch_value_ to keep +// track of whether the cluster has any node with dynamic batch size. We use +// field static_batch_value_ to keep track of whether the cluster has any node +// with static batch size and what the value of the static batch size, if any. +// Examples: +// cluster: a = a1[1,3] + a1[1,3] +// ClusterBatchSize: has_dynamic_batch_size_ = false +// static_batch_value_ = {has value, 1} +// +// cluster: b = b1[-1,3] + b2[-1, 3] +// ClusterBatchSize: has_dynamic_batch_size_ = true +// static_batch_value_ = {has no value} +// +// cluster: a = a1[1,3] + a1[1,3]; b = b1[-1,3] + b2[-1, 3] +// ClusterBatchSize: has_dynamic_batch_size_ = true +// static_batch_value_ = {has value, 1} +// +// When constructing cluster for explicit batch mode, all ClusterBatchSize is +// irrelevant. 
+// +// +absl::optional static_batch_value_; +class ClusterBatchSize { + public: + ClusterBatchSize() + : has_dynamic_batch_value_(false), static_batch_value_(absl::nullopt) {} + + bool operator==(const ClusterBatchSize& b) { + return HasDynamicBatchValue() == b.HasDynamicBatchValue() && + static_batch_value_ == b.static_batch_value_; + } + + bool operator!=(const ClusterBatchSize& b) { return !(*this == b); } + + int GetStaticBatchValue() const { + DCHECK(HasStaticBatchValue()); + return static_batch_value_.value(); + } + + // Sets the batch size value assuming that the object doesn't have a batch + // size value yet: + // a non-negative input value representing a known batch size. + // a negative input value representing a dynamic batch size. + ClusterBatchSize SetBatchSizeValue(int value) { + if (value < 0) { + has_dynamic_batch_value_ = true; + return *this; + } + static_batch_value_ = value; + return *this; + } + + bool MergeIfCompatible(const ClusterBatchSize& b) { + bool is_compatible = MergeIfCompatible(b.static_batch_value_); + if (!is_compatible) return false; + + if (!HasDynamicBatchValue() && b.HasDynamicBatchValue()) { + has_dynamic_batch_value_ = true; + } + + return true; + } + + // Returns a string for the batch size value. If the object has a static + // batch size value, return a string for the value. If the object has a + // dynamic size value, return -1. Otherwise, returns -2 to represent that + // a batch size hasn't been set yet. + string ToString() const { + string s; + absl::StrAppendFormat(&s, "batch_size=(%d,%d,", HasDynamicBatchValue(), + HasStaticBatchValue()); + if (HasStaticBatchValue()) { + absl::StrAppendFormat(&s, "%d", GetStaticBatchValue()); + } + absl::StrAppend(&s, ")"); + return s; + } + + private: + bool HasStaticBatchValue() const { return static_batch_value_.has_value(); } + bool HasDynamicBatchValue() const { return has_dynamic_batch_value_; } + + private: + bool MergeIfCompatible(const absl::optional& b) { + bool is_compatible = !HasStaticBatchValue() || !b.has_value() || + GetStaticBatchValue() == b.value(); + if (!is_compatible) { + return false; + } + if (!HasStaticBatchValue() && b.has_value()) { + static_batch_value_ = b; + } + return true; + } + + private: + // To track whether the cluster has any node with dynamic batch size. + bool has_dynamic_batch_value_; + // To track whether the cluster has any node with static batch size, and the + // unique value for static batch size. + absl::optional static_batch_value_; +}; + +inline std::ostream& operator<<(std::ostream& os, + const ClusterBatchSize& batch_size) { + return os << batch_size.ToString(); +} + +// Represents a disjoint set of copyable values with type T. We use this data +// structure to construct clusters for TRTEngineOp. As such, this data structure +// has a field to record the batch size for the current cluster and merges the +// corresponding batch sizes when merging two clusters. Most of the methods in +// this class are side-effecting as they also compress the path from the object +// to the parent of its containing set. template class UnionFind { public: UnionFind() : size_(1), parent_(nullptr) {} - explicit UnionFind(const T& v) : size_(1), parent_(nullptr), value_(v) {} + explicit UnionFind(const T& v, ClusterBatchSize batch_size) + : size_(1), + cluster_batch_size_(batch_size), + parent_(nullptr), + value_(v) {} - // Returns the number of elements in a cluster. 
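The compatibility rule behind MergeIfCompatible above boils down to: all static batch sizes in a cluster must agree on one value, and the dynamic flag of the merged cluster is the OR of the two sides. Below is a minimal standalone model of that rule, assuming C++17 for std::optional; BatchSizeModel is an illustrative name, not the class in this header.

#include <cstdio>
#include <optional>

// Simplified model: a cluster may contain any number of nodes with dynamic
// batch size, but every node with a static batch size must agree on a single
// value.
struct BatchSizeModel {
  bool has_dynamic = false;
  std::optional<int> static_value;

  bool MergeIfCompatible(const BatchSizeModel& other) {
    if (static_value.has_value() && other.static_value.has_value() &&
        *static_value != *other.static_value) {
      return false;  // two different static batch sizes cannot be clustered
    }
    if (!static_value.has_value()) static_value = other.static_value;
    has_dynamic = has_dynamic || other.has_dynamic;
    return true;
  }
};

int main() {
  BatchSizeModel a{false, 2};            // e.g. a1[2,3] + a2[2,3]
  BatchSizeModel b{true, std::nullopt};  // e.g. b1[-1,3] + b2[-1,3]
  BatchSizeModel c{false, 5};            // e.g. c1[5,3] + c2[5,3]
  std::printf("merge(static 2, dynamic)  -> %s\n",
              a.MergeIfCompatible(b) ? "ok" : "rejected");
  std::printf("merge(static 2, static 5) -> %s\n",
              a.MergeIfCompatible(c) ? "ok" : "rejected");
  return 0;
}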
+ // Returns the number of elements in the cluster and compresses the path from + // this object to the root of the cluster. int Size() { return FindRoot()->size_; } - // Merges this cluster with 'other'. This cluster's value becomes - // the value of the merged cluster; the value of 'other' is ignored. - void Merge(UnionFind* other); + // Returns the batch size of the cluster and compress the path from this + // object to the root object. + ClusterBatchSize BatchSize() { return FindRoot()->cluster_batch_size_; } - // Each cluster has an associated value. Retrieves the value associated - // with this cluster. + // Merges this cluster with 'other'. This cluster's size_ is updated to + // the size of the merged cluster; the size_ of 'other' becomes inaccessible + // as only the size_ of the root object is accessible. + Status Merge(UnionFind* other); + + // Retrieves the value for the root of the cluster. T& ParentValue() { return FindRoot()->value_; } - // Get the original value of this node. + // Returns the value for the object. T& Value() { return value_; } private: - // Finds the root element of the cluster. Performs path compression. + // Returns the root object for the cluster and compresses the path from this + // object to the root object. UnionFind* FindRoot(); int size_; + ClusterBatchSize cluster_batch_size_; UnionFind* parent_; T value_; }; template -void UnionFind::Merge(UnionFind* other) { +Status UnionFind::Merge(UnionFind* other) { UnionFind* a = FindRoot(); UnionFind* b = other->FindRoot(); - if (a == b) return; + if (a == b) return Status::OK(); + ClusterBatchSize batch_size = a->cluster_batch_size_; + bool merged = batch_size.MergeIfCompatible(other->cluster_batch_size_); + if (!merged) { + return errors::Internal("trying to merge incompatible cluster."); + } + + a->cluster_batch_size_ = batch_size; b->parent_ = a; a->size_ += b->size_; + return Status::OK(); } template @@ -76,4 +217,7 @@ UnionFind* UnionFind::FindRoot() { } // namespace tensorrt } // namespace tensorflow +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + #endif // TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_ diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index a5332385994..55341c0a01f 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -81,7 +81,7 @@ tf_portable_proto_library( name = "portable_tf2xla_proto", config_string = "allow_all:true", header_outs = ["//tensorflow/compiler/tf2xla/tf2xla.proto.h"], - portable_deps = ["//tensorflow/core:portable_proto_lib_full_runtime"], + portable_deps = ["//tensorflow/core:portable_proto_lib"], proto_deps = [ ":tf2xla_proto", "//tensorflow/core:protos_all", @@ -182,6 +182,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/strings", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Shape", "@llvm-project//mlir:StandardOps", ], ) @@ -703,12 +704,8 @@ cc_library( deps = [ "//tensorflow/compiler/mlir:mlir_graph_optimization_pass", "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", - "//tensorflow/compiler/mlir/tensorflow:device_util", - "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", - "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/core:core_cpu", - "@com_google_absl//absl/container:flat_hash_set", + "//tensorflow/core:lib", "@llvm-project//llvm:support", ], alwayslink = 1, diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 
4780bd7455e..bfdfe38305b 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -103,6 +103,7 @@ tf_kernel_library( "spacetodepth_op.cc", "sparse_to_dense_op.cc", "split_op.cc", + "spmd_manual_sharding_ops.cc", "stack_ops.cc", "stateful_random_ops.cc", "stateless_random_ops.cc", diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc index bb2c0d9ddb8..5dbc083368c 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc @@ -28,6 +28,15 @@ limitations under the License. namespace tensorflow { namespace { +absl::InlinedVector SliceVector(xla::XlaOp input, int64 rank) { + absl::InlinedVector scalar_indices; + scalar_indices.reserve(rank); + for (int i = 0; i < rank; i++) + scalar_indices.push_back( + xla::Reshape(xla::Slice(input, {i}, {i + 1}, {1}), {})); + return scalar_indices; +} + class DynamicUpdateSliceOp : public XlaOpKernel { public: explicit DynamicUpdateSliceOp(OpKernelConstruction* context) @@ -41,21 +50,23 @@ class DynamicUpdateSliceOp : public XlaOpKernel { const TensorShape update_shape = ctx->InputShape("update"); const TensorShape index_shape = ctx->InputShape("indices"); + int64 rank = input_shape.dims(); OP_REQUIRES( ctx, TensorShapeUtils::IsVector(index_shape) && - index_shape.num_elements() == input_shape.dims(), + index_shape.num_elements() == rank, errors::InvalidArgument("index must be a vector with length equal to " "the number of input dimensions")); OP_REQUIRES( - ctx, input_shape.dims() == update_shape.dims(), + ctx, rank == update_shape.dims(), errors::InvalidArgument("input and update must have the same rank," " input shape is ", input_shape.DebugString(), "; update shape is ", update_shape.DebugString())); + xla::XlaOp indices = ctx->Input("indices"); xla::XlaOp result = xla::DynamicUpdateSlice( - ctx->Input("input"), ctx->Input("update"), ctx->Input("indices")); + ctx->Input("input"), ctx->Input("update"), SliceVector(indices, rank)); ctx->SetOutput(0, result); } }; @@ -76,17 +87,18 @@ class DynamicSliceOp : public XlaOpKernel { const TensorShape start_indices_shape = ctx->InputShape("start_indices"); const TensorShape size_indices_shape = ctx->InputShape("size_indices"); + int64 rank = input_shape.dims(); OP_REQUIRES(ctx, TensorShapeUtils::IsVector(start_indices_shape) && - start_indices_shape.num_elements() == input_shape.dims(), + start_indices_shape.num_elements() == rank, errors::InvalidArgument( "start_indices must be a vector with length equal to " "input rank, but input rank is ", - input_shape.dims(), " and start_indices has shape ", + rank, " and start_indices has shape ", start_indices_shape.DebugString())); OP_REQUIRES(ctx, TensorShapeUtils::IsVector(size_indices_shape) && - size_indices_shape.num_elements() == input_shape.dims(), + size_indices_shape.num_elements() == rank, errors::InvalidArgument( "size_indices must be a vector with length equal to " "input rank, but input rank is ", @@ -96,8 +108,10 @@ class DynamicSliceOp : public XlaOpKernel { std::vector size_indices; OP_REQUIRES_OK( ctx, ctx->ConstantInputAsIntVector("size_indices", &size_indices)); + + xla::XlaOp start_indices = ctx->Input("start_indices"); xla::XlaOp result = xla::DynamicSlice( - ctx->Input("input"), ctx->Input("start_indices"), size_indices); + ctx->Input("input"), SliceVector(start_indices, rank), size_indices); ctx->SetOutput(0, result); } }; diff --git 
a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index 17d0b87edda..7f274c6b00f 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -42,19 +42,17 @@ class SliceOp : public XlaOpKernel { const TensorShape begin_tensor_shape = ctx->InputShape(1); const TensorShape size_tensor_shape = ctx->InputShape(2); + const int input_dims = input_shape.dims(); OP_REQUIRES( ctx, TensorShapeUtils::IsVector(begin_tensor_shape) && TensorShapeUtils::IsVector(size_tensor_shape) && - begin_tensor_shape.num_elements() == input_shape.dims() && - size_tensor_shape.num_elements() == input_shape.dims(), + begin_tensor_shape.num_elements() == input_dims && + size_tensor_shape.num_elements() == input_dims, errors::InvalidArgument( "Expected begin and size arguments to be 1-D tensors of size ", - input_shape.dims(), ", but got shapes ", - begin_tensor_shape.DebugString(), " and ", - size_tensor_shape.DebugString(), " instead.")); - - const int input_dims = input_shape.dims(); + input_dims, ", but got shapes ", begin_tensor_shape.DebugString(), + " and ", size_tensor_shape.DebugString(), " instead.")); std::vector begin; std::vector size; @@ -129,7 +127,15 @@ class SliceOp : public XlaOpKernel { input_shape.dim_size(i), "], but ", "got ", size[i])); } - ctx->SetOutput(0, xla::DynamicSlice(ctx->Input(0), ctx->Input(1), size)); + + absl::InlinedVector scalar_indices; + scalar_indices.reserve(input_dims); + xla::XlaOp begin = ctx->Input("begin"); + for (int i = 0; i < input_dims; i++) + scalar_indices.push_back( + xla::Reshape(xla::Slice(begin, {i}, {i + 1}, {1}), {})); + + ctx->SetOutput(0, xla::DynamicSlice(ctx->Input(0), scalar_indices, size)); } } }; diff --git a/tensorflow/compiler/tf2xla/kernels/spmd_manual_sharding_ops.cc b/tensorflow/compiler/tf2xla/kernels/spmd_manual_sharding_ops.cc new file mode 100644 index 00000000000..cd28fe8fa3f --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/spmd_manual_sharding_ops.cc @@ -0,0 +1,147 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class XlaSpmdFullToShardShapeOp : public XlaOpKernel { + public: + explicit XlaSpmdFullToShardShapeOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("manual_sharding", &manual_sharding_str_)); + } + + ~XlaSpmdFullToShardShapeOp() override = default; + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp input = ctx->Input(0); + auto input_shape_or = ctx->InputXlaShape(0); + OP_REQUIRES_OK(ctx, input_shape_or.status()); + xla::OpSharding sharding; + if (!sharding.ParseFromString(manual_sharding_str_)) { + OP_REQUIRES_OK(ctx, + xla::InvalidArgument("manual_sharding attribute was not a " + "valid encoded xla::OpSharding " + "proto.")); + } + auto output_shape = input_shape_or.ValueOrDie(); + int64 rank = output_shape.rank(); + if (sharding.type() == xla::OpSharding::OTHER) { + for (int64 i = 0; i < rank; ++i) { + int64 partitions_i = sharding.tile_assignment_dimensions(i); + if (partitions_i == 1) continue; + int64 dim_size = + xla::CeilOfRatio(output_shape.dimensions(i), partitions_i); + output_shape.set_dimensions(i, dim_size); + } + } + xla::XlaOp input_annotation; + { + // Annotate the full-shape input with the manual sharding. + xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(), + sharding); + input_annotation = + xla::CustomCall(ctx->builder(), /*call_target_name=*/"Sharding", + {input}, input_shape_or.ValueOrDie()); + } + + { + // Annotate the shard-shape output with replicated sharding, so that the + // partitioner will leave it as is. 
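XlaSpmdFullToShardShapeOp above shrinks every tiled dimension to CeilOfRatio(dim, partitions) while leaving untiled dimensions unchanged. The standalone sketch below shows just that arithmetic; ShardShape is an illustrative helper, not part of the XLA API.

#include <cstdint>
#include <cstdio>
#include <vector>

// Each dimension split over N partitions shrinks to ceil(dim / N); dimensions
// assigned a single partition keep their size.
std::vector<int64_t> ShardShape(const std::vector<int64_t>& full_shape,
                                const std::vector<int64_t>& partitions) {
  std::vector<int64_t> shard = full_shape;
  for (size_t i = 0; i < shard.size(); ++i) {
    if (partitions[i] == 1) continue;
    shard[i] = (full_shape[i] + partitions[i] - 1) / partitions[i];  // ceil
  }
  return shard;
}

int main() {
  // A [9, 8] tensor tiled over a 2x4 device mesh yields a [5, 2] shard per
  // device; the dimension of size 9 is padded up to 10 before splitting, which
  // is why the op's documentation says the padding region is masked with 0s.
  const std::vector<int64_t> shard = ShardShape({9, 8}, {2, 4});
  std::printf("shard shape: [%lld, %lld]\n",
              static_cast<long long>(shard[0]),
              static_cast<long long>(shard[1]));
  return 0;
}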
+ xla::OpSharding replicated; + replicated.set_type(xla::OpSharding::REPLICATED); + xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(), + replicated); + auto output = xla::CustomCall(ctx->builder(), + /*call_target_name=*/"SPMDFullToShardShape", + {input_annotation}, output_shape); + ctx->SetOutput(0, output); + } + } + + private: + string manual_sharding_str_; + TF_DISALLOW_COPY_AND_ASSIGN(XlaSpmdFullToShardShapeOp); +}; + +class XlaSpmdShardToFullShapeOp : public XlaOpKernel { + public: + explicit XlaSpmdShardToFullShapeOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("full_shape", &full_shape_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("manual_sharding", &manual_sharding_str_)); + } + + ~XlaSpmdShardToFullShapeOp() override = default; + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp input = ctx->Input(0); + auto input_shape_or = ctx->InputXlaShape(0); + OP_REQUIRES_OK(ctx, input_shape_or.status()); + auto output_shape = TensorShapeToXLAShape( + input_shape_or.ValueOrDie().element_type(), full_shape_); + + xla::OpSharding sharding; + if (!sharding.ParseFromString(manual_sharding_str_)) { + OP_REQUIRES_OK(ctx, + xla::InvalidArgument("manual_sharding attribute was not a " + "valid encoded xla::OpSharding " + "proto.")); + } + xla::XlaOp input_annotation; + { + // Annotate the shard-shape input with replicated sharding, so that the + // partitioner will leave it as is. + xla::OpSharding replicated; + replicated.set_type(xla::OpSharding::REPLICATED); + xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(), + replicated); + input_annotation = + xla::CustomCall(ctx->builder(), /*call_target_name=*/"Sharding", + {input}, input_shape_or.ValueOrDie()); + } + + { + // Annotate the full-shape output with the manual sharding. + xla::XlaScopedShardingAssignment assign_sharding(ctx->builder(), + sharding); + ctx->SetOutput( + 0, xla::CustomCall(ctx->builder(), + /*call_target_name=*/"SPMDShardToFullShape", + {input_annotation}, output_shape)); + } + } + + private: + TensorShape full_shape_; + string manual_sharding_str_; + TF_DISALLOW_COPY_AND_ASSIGN(XlaSpmdShardToFullShapeOp); +}; + +REGISTER_XLA_OP(Name("XlaSpmdFullToShardShape"), XlaSpmdFullToShardShapeOp); +REGISTER_XLA_OP(Name("XlaSpmdShardToFullShape"), XlaSpmdShardToFullShapeOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc index 6d0d569724f..c398e5f129e 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc @@ -18,10 +18,18 @@ limitations under the License. #include #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h" +#include "tensorflow/core/lib/monitoring/gauge.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { +auto* mlir_bridge_gauge_v1 = monitoring::Gauge::New( + "/tensorflow/config/experimental/enable_mlir_bridge_gauge_v1", + "Tracks usage of the MLIR-based TF2XLA bridge among TF1 models"); +auto* mlir_bridge_gauge_v2 = monitoring::Gauge::New( + "/tensorflow/config/experimental/enable_mlir_bridge_gauge_v2", + "Tracks usage of the MLIR-based TF2XLA bridge among TF2 models"); + // This runs the first phase of the "bridge", transforming the graph in a form // that can be executed with delegation of some computations to an accelerator. 
// This builds on the model of XLA where a subset of the graph is encapsulated @@ -31,11 +39,13 @@ namespace tensorflow { Status MlirBridgePass::Run(const ConfigProto& config_proto, mlir::ModuleOp module) { if (!config_proto.experimental().enable_mlir_bridge()) { - VLOG(1) << "Skipping MLIR Bridge Pass, session flag not enabled"; + VLOG(0) << "Skipping MLIR TPU Bridge, session flag not enabled"; + mlir_bridge_gauge_v2->GetCell()->Set(false); return Status::OK(); } - VLOG(1) << "Running MLIR Bridge Pass"; + VLOG(0) << "Running MLIR TPU Bridge"; + mlir_bridge_gauge_v2->GetCell()->Set(true); TF_RETURN_IF_ERROR( mlir::TFTPU::TPUBridge(module, /*enable_logging=*/VLOG_IS_ON(1))); @@ -47,11 +57,13 @@ Status MlirBridgeV1CompatPass::Run(const GraphOptimizationPassOptions& options, if (options.is_function_graph) return Status::OK(); if (!options.session_options->config.experimental().enable_mlir_bridge()) { - VLOG(1) << "Skipping MLIR Bridge V1 Compat Pass, session flag not enabled"; + VLOG(0) << "Skipping MLIR TPU Bridge V1 Compat, session flag not enabled"; + mlir_bridge_gauge_v1->GetCell()->Set(false); return Status::OK(); } - VLOG(1) << "Running MLIR Bridge V1 Compat Pass"; + VLOG(0) << "Running MLIR TPU Bridge V1 Compat"; + mlir_bridge_gauge_v1->GetCell()->Set(true); TF_RETURN_IF_ERROR( mlir::TFTPU::TPUBridgeV1Compat(module, /*enable_logging=*/VLOG_IS_ON(1))); diff --git a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc index daf261fa5d8..43793be56a7 100644 --- a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc +++ b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -95,6 +96,7 @@ static void RegisterDialects() { mlir::registerDialect(); mlir::registerDialect(); mlir::registerDialect(); + mlir::registerDialect(); return true; }(); (void)init_once; diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index c0bf423a644..862da1f3f95 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -648,6 +648,62 @@ This op has better TPU performance since it doesn't have explicitly reshape and transpose operations as tf.einsum does. 
)doc"); +REGISTER_OP("XlaSpmdFullToShardShape") + .Input("input: T") + .Output("output: T") + .Attr("T: type") + .Attr("manual_sharding: string") + .SetShapeFn([](shape_inference::InferenceContext* c) { + auto input_handle = c->input(0); + if (!c->RankKnown(input_handle)) { + return shape_inference::UnknownShape(c); + } + string sharding_attr; + TF_RETURN_IF_ERROR(c->GetAttr("manual_sharding", &sharding_attr)); + std::vector dims; + for (int64 i = 0; i < c->Rank(input_handle); ++i) { + auto dim = c->Value(c->Dim(input_handle, i)); + xla::OpSharding sharding; + sharding.ParseFromString(sharding_attr); + int64 partitions_i = sharding.tile_assignment_dimensions(i); + if (dim != shape_inference::InferenceContext::kUnknownDim && + sharding.type() == xla::OpSharding::OTHER && partitions_i != 1) { + dim = (dim + partitions_i - 1) / partitions_i; + } + dims.push_back(c->MakeDim(dim)); + } + c->set_output(0, c->MakeShape(dims)); + return Status::OK(); + }) + .Doc(R"doc( +An op used by XLA SPMD partitioner to switch from automatic partitioning to +manual partitioning. It annotates the input (full-shape, to be automatically +partitioned) with the same sharding used by manual partitioning, and outputs a +shard-shaped tensor to be consumed by later manually-partitioned ops. If the +shape is not evenly partitionable, the padding region will be masked with 0s. +)doc"); + +REGISTER_OP("XlaSpmdShardToFullShape") + .Input("input: T") + .Output("output: T") + .Attr("T: type") + .Attr("manual_sharding: string") + .Attr("full_shape: shape") + .SetShapeFn([](shape_inference::InferenceContext* c) { + TensorShape shape_attr; + TF_RETURN_IF_ERROR(c->GetAttr("full_shape", &shape_attr)); + shape_inference::ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s)); + c->set_output(0, s); + return Status::OK(); + }) + .Doc(R"doc( +An op used by XLA SPMD partitioner to switch from manual partitioning to +automatic partitioning. It converts the shard-shaped, manually partitioned input +into full-shaped tensor to be partitioned automatically with the same sharding +used by manual partitioning. 
+)doc"); + REGISTER_OP("XlaSharding") .Input("input: T") .Output("output: T") diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 0df61da57a3..c59c47e92fb 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -418,6 +418,26 @@ def _sharding_grad(op, grad): return [grad] +spmd_full_to_shard_shape = gen_xla_ops.xla_spmd_full_to_shard_shape +spmd_shard_to_full_shape = gen_xla_ops.xla_spmd_shard_to_full_shape + + +@ops.RegisterGradient("XlaSpmdFullToShardShape") +def _spmd_full_to_shard_shape_grad(op, grad): + s2f = gen_xla_ops.xla_spmd_shard_to_full_shape( + grad, + manual_sharding=op.get_attr("manual_sharding"), + full_shape=op.inputs[0].shape.as_list()) + return [s2f] + + +@ops.RegisterGradient("XlaSpmdShardToFullShape") +def _spmd_shard_to_full_shape_grad(op, grad): + f2s = gen_xla_ops.xla_spmd_full_to_shard_shape( + grad, manual_sharding=op.get_attr("manual_sharding")) + return [f2s] + + sort = gen_xla_ops.xla_sort key_value_sort = gen_xla_ops.xla_key_value_sort while_loop = gen_xla_ops.xla_while diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc index 0aa139ce4f0..49f108ed6c8 100644 --- a/tensorflow/compiler/tf2xla/xla_expression.cc +++ b/tensorflow/compiler/tf2xla/xla_expression.cc @@ -121,6 +121,9 @@ xla::StatusOr> XlaExpression::ResolveConstant( handle().builder()->IsConstant(handle())); if (!is_constant) return {absl::nullopt}; + if (!client) + return errors::InvalidArgument("client is required to resolve constant"); + TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph, handle().builder()->BuildConstantSubGraph( handle(), dynamic_dimension_is_minus_one)); diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index a394de1a9e8..2c6edf5389e 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -175,8 +175,9 @@ Status XlaOpKernelContext::ConstantInputReshaped( int index, absl::Span new_dims, xla::Literal* constant_literal) { XlaExpression e = InputExpression(index); + auto* client = compiler() ? 
compiler()->client() : nullptr; xla::StatusOr> constant_or_status = - e.ResolveConstant(compiler()->client(), dynamic_dimension_is_minus_one_); + e.ResolveConstant(client, dynamic_dimension_is_minus_one_); if (!constant_or_status.ok()) { Status status = constant_or_status.status(); errors::AppendToMessage(&status, "while evaluating input ", index, " of ", diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 1350f9e3e0b..45f49cee328 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -17,7 +17,6 @@ package_group( "//tensorflow/compiler/...", "//tensorflow/python/tpu/...", "//third_party/py/jax/...", - "//third_party/tf_runtime/tools/tf_kernel_gen/...", ], ) @@ -332,6 +331,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc index cd52e2f5e45..404f9eb7519 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.cc +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -70,6 +70,12 @@ ExecutableBuildOptions& ExecutableBuildOptions::set_num_partitions( return *this; } +ExecutableBuildOptions& ExecutableBuildOptions::set_use_spmd_partitioning( + bool use_spmd_partitioning) { + use_spmd_partitioning_ = use_spmd_partitioning; + return *this; +} + ExecutableBuildOptions& ExecutableBuildOptions::set_device_assignment( const DeviceAssignment& device_assignment) { device_assignment_ = device_assignment; diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index 360ad0260df..9a7fdd974b1 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -77,6 +77,11 @@ class ExecutableBuildOptions { int num_partitions() const { return num_partitions_; } ExecutableBuildOptions& set_num_partitions(int num_partitions); + // Indicates whether to use SPMD (true) or MPMD (false) partitioning when + // num_partitions > 1 and XLA is requested to partition the input program. + bool use_spmd_partitioning() const { return use_spmd_partitioning_; } + ExecutableBuildOptions& set_use_spmd_partitioning(bool use_spmd_partitioning); + // If set, this specifies a static device assignment for the computation. 
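The use_spmd_partitioning option introduced above only matters when num_partitions > 1 and XLA is asked to partition the program. A minimal, illustrative sketch of a client opting into SPMD rather than MPMD partitioning (the helper function and partition count are placeholders, not part of this change):

#include "tensorflow/compiler/xla/client/executable_build_options.h"

xla::ExecutableBuildOptions MakeSpmdBuildOptions(int num_partitions) {
  xla::ExecutableBuildOptions options;
  // Partition the program across `num_partitions` devices...
  options.set_num_partitions(num_partitions);
  // ...and ask for the SPMD partitioner rather than MPMD to do it.
  options.set_use_spmd_partitioning(true);
  return options;
}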
// Otherwise, the computation will be compiled generically and can be run with // any device assignment compatible with the computation's replica and @@ -104,6 +109,7 @@ class ExecutableBuildOptions { se::DeviceMemoryAllocator* device_allocator_ = nullptr; int num_replicas_ = 1; int num_partitions_ = 1; + bool use_spmd_partitioning_ = false; absl::optional device_assignment_; bool alias_passthrough_params_ = false; }; diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc index 32796dd8d70..9b8156efe5b 100644 --- a/tensorflow/compiler/xla/client/lib/math_test.cc +++ b/tensorflow/compiler/xla/client/lib/math_test.cc @@ -298,6 +298,15 @@ XLA_TEST_F(MathTest, SqrtSixValues) { ComputeAndCompareR1(&builder, expected, {}, error_spec_); } +XLA_TEST_F(MathTest, CbrtSixValues) { + XlaBuilder builder(TestName()); + auto x = ConstantR1(&builder, {8.0, 1.0, 4096.0, -64.0, 1.728, 1331}); + Cbrt(x); + + std::vector expected = {2, 1, 16, -4, 1.2, 11}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.001)); +} + XLA_TEST_F(MathTest, SinhSmallValues) { XlaBuilder builder(TestName()); auto x = ConstantR1(&builder, {1e-3, 1e-5, 1e-7, 1e-9, 1e-11}); diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 7de4cd4b3c7..a4e5b936153 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -860,34 +860,10 @@ XlaOp XlaBuilder::SliceInDim(XlaOp operand, int64 start_index, }); } -XlaOp XlaBuilder::DynamicSlice(XlaOp operand, XlaOp start_indices, - absl::Span slice_sizes) { - return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; - - TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); - TF_ASSIGN_OR_RETURN(const Shape* start_indices_shape, - GetShapePtr(start_indices)); - TF_ASSIGN_OR_RETURN( - Shape shape, ShapeInference::InferDynamicSliceShape( - *operand_shape, {*start_indices_shape}, slice_sizes)); - *instr.mutable_shape() = shape.ToProto(); - - for (int64 size : slice_sizes) { - instr.add_dynamic_slice_sizes(size); - } - - return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, - {operand, start_indices}); - }); -} - XlaOp XlaBuilder::DynamicSlice(XlaOp operand, absl::Span start_indices, absl::Span slice_sizes) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); std::vector start_indices_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes, @@ -898,43 +874,28 @@ XlaOp XlaBuilder::DynamicSlice(XlaOp operand, TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDynamicSliceShape( *operand_shape, start_indices_shapes, slice_sizes)); - *instr.mutable_shape() = shape.ToProto(); - - for (int64 size : slice_sizes) { - instr.add_dynamic_slice_sizes(size); - } - - std::vector operands = {operand}; - operands.insert(operands.end(), start_indices.begin(), start_indices.end()); - return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands); + return DynamicSliceInternal(shape, operand, start_indices, slice_sizes); }); } -XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update, - XlaOp start_indices) { - return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; +StatusOr XlaBuilder::DynamicSliceInternal( + const Shape& shape, XlaOp operand, absl::Span start_indices, + absl::Span slice_sizes) { + HloInstructionProto instr; + *instr.mutable_shape() = 
shape.ToProto(); - TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); - TF_ASSIGN_OR_RETURN(const Shape* update_shape, GetShapePtr(update)); - TF_ASSIGN_OR_RETURN(const Shape* start_indices_shape, - GetShapePtr(start_indices)); - TF_ASSIGN_OR_RETURN( - Shape shape, - ShapeInference::InferDynamicUpdateSliceShape( - *operand_shape, *update_shape, {*start_indices_shape})); - *instr.mutable_shape() = shape.ToProto(); + for (int64 size : slice_sizes) { + instr.add_dynamic_slice_sizes(size); + } - return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice, - {operand, update, start_indices}); - }); + std::vector operands = {operand}; + operands.insert(operands.end(), start_indices.begin(), start_indices.end()); + return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands); } XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update, absl::Span start_indices) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(const Shape* update_shape, GetShapePtr(update)); std::vector start_indices_shape_ptrs; @@ -946,15 +907,22 @@ XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update, TF_ASSIGN_OR_RETURN( Shape shape, ShapeInference::InferDynamicUpdateSliceShape( *operand_shape, *update_shape, start_indices_shapes)); - *instr.mutable_shape() = shape.ToProto(); - - std::vector operands = {operand, update}; - operands.insert(operands.end(), start_indices.begin(), start_indices.end()); - return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice, - operands); + return DynamicUpdateSliceInternal(shape, operand, update, start_indices); }); } +StatusOr XlaBuilder::DynamicUpdateSliceInternal( + const Shape& shape, XlaOp operand, XlaOp update, + absl::Span start_indices) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + + std::vector operands = {operand, update}; + operands.insert(operands.end(), start_indices.begin(), start_indices.end()); + return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice, + operands); +} + XlaOp XlaBuilder::ConcatInDim(absl::Span operands, int64 dimension) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -1301,7 +1269,6 @@ XlaOp XlaBuilder::ConvGeneralDilated( int64 feature_group_count, int64 batch_group_count, const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs)); TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs)); TF_RETURN_IF_ERROR( @@ -1314,30 +1281,45 @@ XlaOp XlaBuilder::ConvGeneralDilated( window_dimensions[i] = rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i)); } - TF_ASSIGN_OR_RETURN(*instr.mutable_window(), + + TF_ASSIGN_OR_RETURN(Window window, ShapeInference::InferWindowFromDimensions( window_dimensions, window_strides, padding, lhs_dilation, rhs_dilation)); - - TF_ASSIGN_OR_RETURN( - Shape shape, ShapeInference::InferConvolveShape( - *lhs_shape, *rhs_shape, feature_group_count, - batch_group_count, instr.window(), dimension_numbers)); - *instr.mutable_shape() = shape.ToProto(); - - *instr.mutable_convolution_dimension_numbers() = dimension_numbers; - instr.set_feature_group_count(feature_group_count); - instr.set_batch_group_count(batch_group_count); - - if (precision_config != nullptr) { - *instr.mutable_precision_config() = *precision_config; - } - - return 
AddInstruction(std::move(instr), HloOpcode::kConvolution, - {lhs, rhs}); + TF_ASSIGN_OR_RETURN(Shape shape, + ShapeInference::InferConvolveShape( + *lhs_shape, *rhs_shape, feature_group_count, + batch_group_count, window, dimension_numbers)); + return ConvGeneralDilatedInternal(shape, lhs, rhs, window, window_strides, + padding, lhs_dilation, rhs_dilation, + dimension_numbers, feature_group_count, + batch_group_count, precision_config); }); } +StatusOr XlaBuilder::ConvGeneralDilatedInternal( + const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + + *instr.mutable_window() = window; + *instr.mutable_convolution_dimension_numbers() = dimension_numbers; + instr.set_feature_group_count(feature_group_count); + instr.set_batch_group_count(batch_group_count); + + if (precision_config != nullptr) { + *instr.mutable_precision_config() = *precision_config; + } + + return AddInstruction(std::move(instr), HloOpcode::kConvolution, {lhs, rhs}); +} + XlaOp XlaBuilder::Fft(XlaOp operand, const FftType fft_type, const absl::Span fft_length) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -1792,8 +1774,6 @@ XlaOp XlaBuilder::RngOp(RandomDistribution distribution, absl::Span parameters, const Shape& shape) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; - // Check the number of parameters per RNG distribution. switch (distribution) { case RandomDistribution::RNG_NORMAL: @@ -1809,14 +1789,20 @@ XlaOp XlaBuilder::RngOp(RandomDistribution distribution, } TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); - *instr.mutable_shape() = shape.ToProto(); - - instr.set_distribution(distribution); - - return AddInstruction(std::move(instr), HloOpcode::kRng, parameters); + return RngOpInternal(distribution, parameters, shape); }); } +StatusOr XlaBuilder::RngOpInternal(RandomDistribution distribution, + absl::Span parameters, + const Shape& shape) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + instr.set_distribution(distribution); + + return AddInstruction(std::move(instr), HloOpcode::kRng, parameters); +} + XlaOp XlaBuilder::RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape) { return RngOp(RandomDistribution::RNG_NORMAL, {mu, sigma}, shape); } @@ -2199,6 +2185,39 @@ XlaOp XlaBuilder::BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean, }); } +XlaOp XlaBuilder::AllGather(XlaOp operand, int64 all_gather_dimension, + int64 shard_count, + absl::Span replica_groups, + const absl::optional& channel_id, + const absl::optional& layout) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); + + TF_ASSIGN_OR_RETURN(Shape inferred_shape, + ShapeInference::InferAllGatherShape( + *operand_shape, all_gather_dimension, shard_count)); + if (layout) { + *inferred_shape.mutable_layout() = *layout; + instr.set_constrain_layout(true); + } + *instr.mutable_shape() = inferred_shape.ToProto(); + + instr.add_dimensions(all_gather_dimension); + for (const ReplicaGroup& group : replica_groups) { + *instr.add_replica_groups() = group; + } + if (channel_id.has_value()) { + instr.set_channel_id(channel_id->handle()); + } + + 
TF_ASSIGN_OR_RETURN( + auto all_gather, + AddInstruction(std::move(instr), HloOpcode::kAllGather, {operand})); + return all_gather; + }); +} + XlaOp XlaBuilder::CrossReplicaSum( XlaOp operand, absl::Span replica_groups) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -3101,20 +3120,11 @@ XlaOp SliceInDim(const XlaOp operand, int64 start_index, int64 limit_index, stride, dimno); } -XlaOp DynamicSlice(const XlaOp operand, const XlaOp start_indices, - absl::Span slice_sizes) { - return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes); -} XlaOp DynamicSlice(const XlaOp operand, absl::Span start_indices, absl::Span slice_sizes) { return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes); } -XlaOp DynamicUpdateSlice(const XlaOp operand, const XlaOp update, - const XlaOp start_indices) { - return operand.builder()->DynamicUpdateSlice(operand, update, start_indices); -} - XlaOp DynamicUpdateSlice(const XlaOp operand, const XlaOp update, absl::Span start_indices) { return operand.builder()->DynamicUpdateSlice(operand, update, start_indices); @@ -3466,6 +3476,16 @@ XlaOp ReduceWindowWithGeneralPadding( base_dilations, window_dilations, padding); } +XlaOp AllGather(const XlaOp operand, int64 all_gather_dimension, + int64 shard_count, + absl::Span replica_groups, + const absl::optional& channel_id, + const absl::optional& layout) { + return operand.builder()->AllGather(operand, all_gather_dimension, + shard_count, replica_groups, channel_id, + layout); +} + XlaOp CrossReplicaSum(const XlaOp operand, absl::Span replica_groups) { return operand.builder()->CrossReplicaSum(operand, replica_groups); @@ -3571,6 +3591,9 @@ XlaOp Imag(const XlaOp operand) { XlaOp Sqrt(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kSqrt, operand); } +XlaOp Cbrt(const XlaOp operand) { + return operand.builder()->UnaryOp(HloOpcode::kCbrt, operand); +} XlaOp Rsqrt(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kRsqrt, operand); } diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 64424b9dd3c..b631514248c 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -421,16 +421,17 @@ class XlaBuilder { virtual XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno); - ABSL_DEPRECATED("Use span-of-indices form instead") - XlaOp DynamicSlice(XlaOp operand, XlaOp start_indices, - absl::Span slice_sizes); XlaOp DynamicSlice(XlaOp operand, absl::Span start_indices, absl::Span slice_sizes); + virtual StatusOr DynamicSliceInternal( + const Shape& shape, XlaOp operand, absl::Span start_indices, + absl::Span slice_sizes); - ABSL_DEPRECATED("Use span-of-indices form instead") - XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, XlaOp start_indices); XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, absl::Span start_indices); + virtual StatusOr DynamicUpdateSliceInternal( + const Shape& shape, XlaOp operand, XlaOp update, + absl::Span start_indices); XlaOp ConcatInDim(absl::Span operands, int64 dimension); virtual StatusOr ConcatInDimInternal(const Shape& shape, @@ -491,6 +492,16 @@ class XlaBuilder { int64 batch_group_count = 1, const PrecisionConfig* precision_config = nullptr); + virtual StatusOr ConvGeneralDilatedInternal( + const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + 
const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, int64 batch_group_count, + const PrecisionConfig* precision_config); + XlaOp Fft(XlaOp operand, FftType fft_type, absl::Span fft_length); @@ -549,6 +560,12 @@ class XlaBuilder { XlaOp CrossReplicaSum(XlaOp operand, absl::Span replica_groups = {}); + XlaOp AllGather( + XlaOp operand, int64 all_gather_dimension, int64 shard_count, + absl::Span replica_groups = {}, + const absl::optional& channel_id = absl::nullopt, + const absl::optional& layout = absl::nullopt); + XlaOp AllReduce( XlaOp operand, const XlaComputation& computation, absl::Span replica_groups = {}, @@ -707,6 +724,10 @@ class XlaBuilder { XlaOp RngOp(RandomDistribution distribution, absl::Span parameters, const Shape& shape); + virtual StatusOr RngOpInternal(RandomDistribution distribution, + absl::Span parameters, + const Shape& shape); + virtual StatusOr InDimBroadcast( const Shape& shape, XlaOp operand, absl::Span broadcast_dimensions); @@ -838,14 +859,10 @@ class XlaBuilder { friend XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno); - friend XlaOp DynamicSlice(XlaOp operand, XlaOp start_indices, - absl::Span slice_sizes); friend XlaOp DynamicSlice(XlaOp operand, absl::Span start_indices, absl::Span slice_sizes); - friend XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, - XlaOp start_indices); friend XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, absl::Span start_indices); @@ -988,6 +1005,11 @@ class XlaBuilder { absl::Span> padding); friend XlaOp CrossReplicaSum(XlaOp operand, absl::Span replica_groups); + friend XlaOp AllGather(XlaOp operand, int64 all_gather_dimension, + int64 shard_count, + absl::Span replica_groups, + const absl::optional& channel_id, + const absl::optional& layout); friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation, absl::Span replica_groups, const absl::optional& channel_id, @@ -1030,6 +1052,7 @@ class XlaBuilder { friend XlaOp Imag(XlaOp operand); friend XlaOp Sqrt(XlaOp operand); friend XlaOp Rsqrt(XlaOp operand); + friend XlaOp Cbrt(XlaOp operand); friend XlaOp Pow(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions); friend XlaOp IsFinite(XlaOp operand); @@ -1412,10 +1435,6 @@ XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index, XlaOp DynamicSlice(XlaOp operand, absl::Span start_indices, absl::Span slice_sizes); -ABSL_DEPRECATED("Use span-of-indices form instead") -XlaOp DynamicSlice(XlaOp operand, XlaOp start_indices, - absl::Span slice_sizes); - // Enqueues a dynamic update slice operation onto the computation, which // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. // The shape of 'update' determines the shape of the slice of 'operand' @@ -1436,9 +1455,6 @@ XlaOp DynamicSlice(XlaOp operand, XlaOp start_indices, XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, absl::Span start_indices); -ABSL_DEPRECATED("Use span-of-indices form instead") -XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, XlaOp start_indices); - // Enqueues a concatenate instruction onto the computation. 'operands' must // have >= 1 entry. 
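With the ABSL_DEPRECATED single-operand start_indices overloads of DynamicSlice and DynamicUpdateSlice removed above, callers pass one scalar index operand per dimension instead of a single R1 operand. A minimal migration sketch (the function name, operand rank, and index values are illustrative only):

#include "tensorflow/compiler/xla/client/xla_builder.h"

namespace xla {

// Takes a 2x2 dynamic slice of a rank-2 operand starting at (0, 2).
XlaOp DynamicSliceExample(XlaBuilder* b, XlaOp operand) {
  // Was (removed overload): DynamicSlice(operand, ConstantR1<int32>(b, {0, 2}), {2, 2});
  XlaOp row_start = ConstantR0<int32>(b, 0);
  XlaOp col_start = ConstantR0<int32>(b, 2);
  return DynamicSlice(operand, {row_start, col_start}, /*slice_sizes=*/{2, 2});
}

}  // namespace xla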
XlaOp ConcatInDim(XlaBuilder* builder, absl::Span operands, @@ -1766,6 +1782,11 @@ XlaOp ReduceWindowWithGeneralPadding( XlaOp CrossReplicaSum(XlaOp operand, absl::Span replica_groups = {}); +XlaOp AllGather(XlaOp operand, int64 all_gather_dimension, int64 shard_count, + absl::Span replica_groups = {}, + const absl::optional& channel_id = absl::nullopt, + const absl::optional& layout = absl::nullopt); + // Enqueues an operation that do an AllReduce of the operand cross cores. Here // AllReduce means doing a reduction on the input operand cross cores and then // broadcasting the reduction result to those cores. The reduction function is @@ -1884,6 +1905,9 @@ XlaOp Imag(XlaOp operand); // Enqueues a sqrt computation onto the computation. XlaOp Sqrt(XlaOp operand); +// Enqueues a cbrt computation onto the computation. +XlaOp Cbrt(XlaOp operand); + // Enqueues a rsqrt computation onto the computation. XlaOp Rsqrt(XlaOp operand); diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index 1fa839b2014..e1733cd179c 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -381,6 +381,18 @@ TEST_F(XlaBuilderTest, Transpose) { EXPECT_THAT(root, op::Transpose(op::Parameter())); } +TEST_F(XlaBuilderTest, AllGather) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x"); + AllGather(x, /*all_gather_dimension=*/1, /*shard_count=*/4); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + + EXPECT_EQ(root->opcode(), HloOpcode::kAllGather); + EXPECT_TRUE( + ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 64}))); +} + TEST_F(XlaBuilderTest, AllToAll) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x"); diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index e6d60e51e75..60a563ee956 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -64,6 +64,9 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_force_host_platform_device_count(1); opts.set_xla_gpu_deterministic_reductions(false); opts.set_xla_cpu_enable_xprof_traceme(true); + // TODO(b/155295372): disable ptxas fallback by default. + opts.set_xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found(true); + opts.set_xla_gpu_unsafe_fallback_to_driver_on_ptxas_error(false); return opts; } @@ -219,340 +222,347 @@ static void AllocateFlags() { return true; }; - flag_objects = new std::vector({ - tensorflow::Flag( - "xla_cpu_enable_fast_math", - bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_math), - flag_values->xla_cpu_enable_fast_math(), - "Enable unsafe fast-math optimizations in the CPU compiler; " - "this may produce faster code at the expense of some accuracy."), - tensorflow::Flag( - "xla_cpu_fast_math_honor_nans", - bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_nans), - flag_values->xla_cpu_fast_math_honor_nans(), - "When xla_cpu_enable_fast_math is true then this controls whether we " - "allow operations to produce NaNs. 
Ignored when " - "xla_cpu_enable_fast_math is false."), - tensorflow::Flag( - "xla_cpu_fast_math_honor_infs", - bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_infs), - flag_values->xla_cpu_fast_math_honor_infs(), - "When xla_cpu_enable_fast_math is true then this controls whether we " - "allow operations to produce infinites. Ignored when " - "xla_cpu_enable_fast_math is false."), - tensorflow::Flag( - "xla_cpu_fast_math_honor_division", - bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_division), - flag_values->xla_cpu_fast_math_honor_division(), - "When xla_cpu_enable_fast_math is true then this controls whether " - "we forbid to use multiplication by the reciprocal instead of " - "division. Ignored when xla_cpu_enable_fast_math is false."), - tensorflow::Flag( - "xla_cpu_fast_math_honor_functions", - bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_functions), - flag_values->xla_cpu_fast_math_honor_functions(), - "When xla_cpu_enable_fast_math is true then this controls whether " - "we forbid to approximate calculations for functions. Ignored when " - "xla_cpu_enable_fast_math is false."), - tensorflow::Flag( - "xla_gpu_enable_fast_min_max", - bool_setter_for(&DebugOptions::set_xla_gpu_enable_fast_min_max), - flag_values->xla_gpu_enable_fast_min_max(), - "Enable fast floating point min/max lowering that does not propagate " - "NaNs."), - tensorflow::Flag( - "xla_llvm_enable_alias_scope_metadata", - bool_setter_for( - &DebugOptions::set_xla_llvm_enable_alias_scope_metadata), - flag_values->xla_llvm_enable_alias_scope_metadata(), - "In LLVM-based backends, enable the emission of " - "!alias.scope metadata in the generated IR."), - tensorflow::Flag( - "xla_llvm_enable_noalias_metadata", - bool_setter_for(&DebugOptions::set_xla_llvm_enable_noalias_metadata), - flag_values->xla_llvm_enable_noalias_metadata(), - "In LLVM-based backends, enable the emission of " - "!noalias metadata in the generated IR."), - tensorflow::Flag( - "xla_llvm_enable_invariant_load_metadata", - bool_setter_for( - &DebugOptions::set_xla_llvm_enable_invariant_load_metadata), - flag_values->xla_llvm_enable_invariant_load_metadata(), - "In LLVM-based backends, enable the emission of " - "!invariant.load metadata in " - "the generated IR."), - tensorflow::Flag( - "xla_llvm_disable_expensive_passes", - bool_setter_for(&DebugOptions::set_xla_llvm_disable_expensive_passes), - flag_values->xla_llvm_disable_expensive_passes(), - "In LLVM-based backends, disable a custom set of " - "expensive optimization passes."), - tensorflow::Flag( - "xla_backend_optimization_level", - int32_setter_for(&DebugOptions::set_xla_backend_optimization_level), - flag_values->xla_backend_optimization_level(), - "Numerical optimization level for the XLA compiler backend."), - tensorflow::Flag( - "xla_disable_hlo_passes", setter_for_xla_disable_hlo_passes, "", - "Comma-separated list of hlo passes to be disabled. These names " - "must exactly match the passes' names; no whitespace around " - "commas."), - tensorflow::Flag( - "xla_enable_hlo_passes_only", setter_for_xla_enable_hlo_passes_only, - "", - "Comma-separated list of hlo passes to be enabled. These names " - "must exactly match the passes' names; no whitespace around " - "commas. The unspecified passes are all disabled."), - tensorflow::Flag( - "xla_disable_all_hlo_passes", - bool_setter_for(&DebugOptions::set_xla_disable_all_hlo_passes), false, - "Disables all HLO passes. 
Notes that some passes are necessary for " - "correctness and the invariants that must be satisfied by 'fully " - "optimized' HLO are different for different devices and may change " - "over time. The only 'guarantee', such as it is, is that if you " - "compile XLA and dump the optimized HLO for some graph, you should " - "be able to run it again on the same device with the same build of " - "XLA."), - tensorflow::Flag( - "xla_embed_ir_in_executable", - bool_setter_for(&DebugOptions::set_xla_embed_ir_in_executable), - flag_values->xla_embed_ir_in_executable(), - "Embed the compiler IR as a string in the executable."), - tensorflow::Flag( - "xla_eliminate_hlo_implicit_broadcast", - bool_setter_for( - &DebugOptions::set_xla_eliminate_hlo_implicit_broadcast), - flag_values->xla_eliminate_hlo_implicit_broadcast(), - "Eliminate implicit broadcasts when lowering user " - "computations to HLO instructions; use explicit " - "broadcast instead."), - tensorflow::Flag( - "xla_cpu_multi_thread_eigen", - bool_setter_for(&DebugOptions::set_xla_cpu_multi_thread_eigen), - flag_values->xla_cpu_multi_thread_eigen(), - "When generating calls to Eigen in the CPU backend, " - "use multi-threaded Eigen mode."), - tensorflow::Flag("xla_gpu_cuda_data_dir", - flag_values->mutable_xla_gpu_cuda_data_dir(), - "If non-empty, specifies a local directory containing " - "ptxas and nvvm libdevice files; otherwise we use " - "those from runfile directories."), - tensorflow::Flag("xla_gpu_ftz", - bool_setter_for(&DebugOptions::set_xla_gpu_ftz), - flag_values->xla_gpu_ftz(), - "If true, flush-to-zero semantics are enabled in the " - "code generated for GPUs."), - tensorflow::Flag( - "xla_gpu_disable_multi_streaming", - bool_setter_for(&DebugOptions::set_xla_gpu_disable_multi_streaming), - flag_values->xla_gpu_disable_multi_streaming(), - "If true, multi-streaming in the GPU backend is disabled."), - tensorflow::Flag( - "xla_gpu_max_kernel_unroll_factor", - int32_setter_for(&DebugOptions::set_xla_gpu_max_kernel_unroll_factor), - flag_values->xla_gpu_max_kernel_unroll_factor(), - "Specify the maximum kernel unroll factor for the GPU backend."), - tensorflow::Flag("xla_gpu_ptx_file", setter_for_xla_gpu_ptx_file, "", - "If non-empty, specifies a file containing ptx to use. " - "The filename prefix must have the same pattern as PTX " - "dumped by XLA. This allows to match one specific " - "module. General workflow. Get the generated module " - "ptx from XLA. Modify it. Then pass it back via this " - "option."), - tensorflow::Flag( - "xla_test_all_output_layouts", - bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts), - flag_values->xla_test_all_output_layouts(), - "Let ClientLibraryTestBase::ComputeAndCompare* test " - "all permutations of output layouts. For example, with " - "a 3D shape, all permutations of the set {0, 1, 2} are " - "tried."), - tensorflow::Flag( - "xla_test_all_input_layouts", - bool_setter_for(&DebugOptions::set_xla_test_all_input_layouts), - flag_values->xla_test_all_input_layouts(), - "Let ClientLibraryTestBase::ComputeAndCompare* test " - "all permutations of *input* layouts. For example, for " - "2 input arguments with 2D shape and 4D shape, the " - "computation will run 2! * 4! 
times for every possible " - "layouts"), - tensorflow::Flag( - "xla_hlo_profile", - bool_setter_for(&DebugOptions::set_xla_hlo_profile), - flag_values->xla_hlo_profile(), - "Instrument the computation to collect per-HLO cycle counts"), - tensorflow::Flag("xla_backend_extra_options", - setter_for_xla_backend_extra_options, "", - "Extra options to pass to a backend; " - "comma-separated list of 'key=val' strings (=val " - "may be omitted); no whitespace around commas."), - tensorflow::Flag( - "xla_gpu_use_cudnn_batchnorm", - bool_setter_for(&DebugOptions::set_xla_gpu_use_cudnn_batchnorm), - flag_values->xla_gpu_use_cudnn_batchnorm(), - "Allows the GPU backend to implement batchnorm HLOs using cudnn, " - "rather than expanding them to a soup of HLOs."), + flag_objects = new std::vector(); + flag_objects->reserve(55); + // Don't use an initializer list for initializing the vector; this would + // create a temporary copy, and exceeds the stack space when compiling with + // certain configurations. + flag_objects->push_back(tensorflow::Flag( + "xla_cpu_enable_fast_math", + bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_math), + flag_values->xla_cpu_enable_fast_math(), + "Enable unsafe fast-math optimizations in the CPU compiler; this may " + "produce faster code at the expense of some accuracy.")); + flag_objects->push_back(tensorflow::Flag( + "xla_cpu_fast_math_honor_nans", + bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_nans), + flag_values->xla_cpu_fast_math_honor_nans(), + "When xla_cpu_enable_fast_math is true then this controls whether we " + "allow operations to produce NaNs. Ignored when " + "xla_cpu_enable_fast_math is false.")); + flag_objects->push_back(tensorflow::Flag( + "xla_cpu_fast_math_honor_infs", + bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_infs), + flag_values->xla_cpu_fast_math_honor_infs(), + "When xla_cpu_enable_fast_math is true then this controls whether we " + "allow operations to produce infinites. Ignored when " + "xla_cpu_enable_fast_math is false.")); + flag_objects->push_back(tensorflow::Flag( + "xla_cpu_fast_math_honor_division", + bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_division), + flag_values->xla_cpu_fast_math_honor_division(), + "When xla_cpu_enable_fast_math is true then this controls whether we " + "forbid to use multiplication by the reciprocal instead of division. " + "Ignored when xla_cpu_enable_fast_math is false.")); + flag_objects->push_back(tensorflow::Flag( + "xla_cpu_fast_math_honor_functions", + bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_functions), + flag_values->xla_cpu_fast_math_honor_functions(), + "When xla_cpu_enable_fast_math is true then this controls whether we " + "forbid to approximate calculations for functions. 
Ignored when " + "xla_cpu_enable_fast_math is false.")); + flag_objects->push_back(tensorflow::Flag( + "xla_gpu_enable_fast_min_max", + bool_setter_for(&DebugOptions::set_xla_gpu_enable_fast_min_max), + flag_values->xla_gpu_enable_fast_min_max(), + "Enable fast floating point min/max lowering that does not propagate " + "NaNs.")); + flag_objects->push_back(tensorflow::Flag( + "xla_llvm_enable_alias_scope_metadata", + bool_setter_for(&DebugOptions::set_xla_llvm_enable_alias_scope_metadata), + flag_values->xla_llvm_enable_alias_scope_metadata(), + "In LLVM-based backends, enable the emission of !alias.scope metadata in " + "the generated IR.")); + flag_objects->push_back(tensorflow::Flag( + "xla_llvm_enable_noalias_metadata", + bool_setter_for(&DebugOptions::set_xla_llvm_enable_noalias_metadata), + flag_values->xla_llvm_enable_noalias_metadata(), + "In LLVM-based backends, enable the emission of !noalias metadata in the " + "generated IR.")); + flag_objects->push_back(tensorflow::Flag( + "xla_llvm_enable_invariant_load_metadata", + bool_setter_for( + &DebugOptions::set_xla_llvm_enable_invariant_load_metadata), + flag_values->xla_llvm_enable_invariant_load_metadata(), + "In LLVM-based backends, enable the emission of !invariant.load metadata " + "in the generated IR.")); + flag_objects->push_back(tensorflow::Flag( + "xla_llvm_disable_expensive_passes", + bool_setter_for(&DebugOptions::set_xla_llvm_disable_expensive_passes), + flag_values->xla_llvm_disable_expensive_passes(), + "In LLVM-based backends, disable a custom set of expensive optimization " + "passes.")); + flag_objects->push_back(tensorflow::Flag( + "xla_backend_optimization_level", + int32_setter_for(&DebugOptions::set_xla_backend_optimization_level), + flag_values->xla_backend_optimization_level(), + "Numerical optimization level for the XLA compiler backend.")); + flag_objects->push_back(tensorflow::Flag( + "xla_disable_hlo_passes", setter_for_xla_disable_hlo_passes, "", + "Comma-separated list of hlo passes to be disabled. These names must " + "exactly match the passes' names; no whitespace around commas.")); + flag_objects->push_back(tensorflow::Flag( + "xla_enable_hlo_passes_only", setter_for_xla_enable_hlo_passes_only, "", + "Comma-separated list of hlo passes to be enabled. These names must " + "exactly match the passes' names; no whitespace around commas. The " + "unspecified passes are all disabled.")); + flag_objects->push_back(tensorflow::Flag( + "xla_disable_all_hlo_passes", + bool_setter_for(&DebugOptions::set_xla_disable_all_hlo_passes), false, + "Disables all HLO passes. Notes that some passes are necessary for " + "correctness and the invariants that must be satisfied by 'fully " + "optimized' HLO are different for different devices and may change " + "over time. 
The only 'guarantee', such as it is, is that if you compile " + "XLA and dump the optimized HLO for some graph, you should be able to " + "run it again on the same device with the same build of XLA.")); + flag_objects->push_back(tensorflow::Flag( + "xla_embed_ir_in_executable", + bool_setter_for(&DebugOptions::set_xla_embed_ir_in_executable), + flag_values->xla_embed_ir_in_executable(), + "Embed the compiler IR as a string in the executable.")); + flag_objects->push_back(tensorflow::Flag( + "xla_eliminate_hlo_implicit_broadcast", + bool_setter_for(&DebugOptions::set_xla_eliminate_hlo_implicit_broadcast), + flag_values->xla_eliminate_hlo_implicit_broadcast(), + "Eliminate implicit broadcasts when lowering user computations to HLO " + "instructions; use explicit broadcast instead.")); + flag_objects->push_back(tensorflow::Flag( + "xla_cpu_multi_thread_eigen", + bool_setter_for(&DebugOptions::set_xla_cpu_multi_thread_eigen), + flag_values->xla_cpu_multi_thread_eigen(), + "When generating calls to Eigen in the CPU backend, use multi-threaded " + "Eigen mode.")); + flag_objects->push_back(tensorflow::Flag( + "xla_gpu_cuda_data_dir", flag_values->mutable_xla_gpu_cuda_data_dir(), + "If non-empty, specifies a local directory containing ptxas and nvvm " + "libdevice files; otherwise we use those from runfile directories.")); + flag_objects->push_back(tensorflow::Flag( + "xla_gpu_ftz", bool_setter_for(&DebugOptions::set_xla_gpu_ftz), + flag_values->xla_gpu_ftz(), + "If true, flush-to-zero semantics are enabled in the code generated for " + "GPUs.")); + flag_objects->push_back(tensorflow::Flag( + "xla_gpu_disable_multi_streaming", + bool_setter_for(&DebugOptions::set_xla_gpu_disable_multi_streaming), + flag_values->xla_gpu_disable_multi_streaming(), + "If true, multi-streaming in the GPU backend is disabled.")); + flag_objects->push_back(tensorflow::Flag( + "xla_gpu_max_kernel_unroll_factor", + int32_setter_for(&DebugOptions::set_xla_gpu_max_kernel_unroll_factor), + flag_values->xla_gpu_max_kernel_unroll_factor(), + "Specify the maximum kernel unroll factor for the GPU backend.")); + flag_objects->push_back(tensorflow::Flag( + "xla_gpu_ptx_file", setter_for_xla_gpu_ptx_file, "", + "If non-empty, specifies a file containing ptx to use. The filename " + "prefix must have the same pattern as PTX dumped by XLA. This allows to " + "match one specific module. General workflow. Get the generated module " + "ptx from XLA. Modify it. Then pass it back via this option.")); + flag_objects->push_back(tensorflow::Flag( + "xla_test_all_output_layouts", + bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts), + flag_values->xla_test_all_output_layouts(), + "Let ClientLibraryTestBase::ComputeAndCompare* test all permutations of " + "output layouts. For example, with a 3D shape, all permutations of the " + "set {0, 1, 2} are tried.")); + flag_objects->push_back(tensorflow::Flag( + "xla_test_all_input_layouts", + bool_setter_for(&DebugOptions::set_xla_test_all_input_layouts), + flag_values->xla_test_all_input_layouts(), + "Let ClientLibraryTestBase::ComputeAndCompare* test all permutations of " + "*input* layouts. For example, for 2 input arguments with 2D shape and " + "4D shape, the computation will run 2! * 4! 
times for every possible " + "layouts")); + flag_objects->push_back(tensorflow::Flag( + "xla_hlo_profile", bool_setter_for(&DebugOptions::set_xla_hlo_profile), + flag_values->xla_hlo_profile(), + "Instrument the computation to collect per-HLO cycle counts")); + flag_objects->push_back(tensorflow::Flag( + "xla_backend_extra_options", setter_for_xla_backend_extra_options, "", + "Extra options to pass to a backend; comma-separated list of 'key=val' " + "strings (=val may be omitted); no whitespace around commas.")); + flag_objects->push_back(tensorflow::Flag( + "xla_gpu_use_cudnn_batchnorm", + bool_setter_for(&DebugOptions::set_xla_gpu_use_cudnn_batchnorm), + flag_values->xla_gpu_use_cudnn_batchnorm(), + "Allows the GPU backend to implement batchnorm HLOs using cudnn, rather " + "than expanding them to a soup of HLOs.")); + flag_objects->push_back( tensorflow::Flag("xla_cpu_use_mkl_dnn", bool_setter_for(&DebugOptions::set_xla_cpu_use_mkl_dnn), flag_values->xla_cpu_use_mkl_dnn(), - "Generate calls to MKL-DNN in the CPU backend."), - tensorflow::Flag( - "xla_gpu_crash_on_verification_failures", - bool_setter_for( - &DebugOptions::set_xla_gpu_crash_on_verification_failures), - flag_values->xla_gpu_crash_on_verification_failures(), - "Crashes the program on extra verification failures, e.g. cuDNN " - "cross checking failures"), - tensorflow::Flag( - "xla_gpu_autotune_level", - int32_setter_for(&DebugOptions::set_xla_gpu_autotune_level), - flag_values->xla_gpu_autotune_level(), - "Set GEMM and Convolution auto-tuning level." - "0 = off; 1 = on; 2 = on+init; 3 = on+init+reinit; 4 = " - "on+init+reinit+check."), - tensorflow::Flag( - "xla_force_host_platform_device_count", - int32_setter_for( - &DebugOptions::set_xla_force_host_platform_device_count), - flag_values->xla_force_host_platform_device_count(), - "Force the host platform to pretend that there are these many " - "host \"devices\". All of these host devices are backed by the same" - "threadpool. Setting this to anything other than 1 can increase " - "overhead from context switching but we let the user override this " - "behavior to help run tests on the host that run models in parallel " - "across multiple devices."), - tensorflow::Flag( - "xla_gpu_disable_gpuasm_optimizations", - bool_setter_for( - &DebugOptions::set_xla_gpu_disable_gpuasm_optimizations), - flag_values->xla_gpu_disable_gpuasm_optimizations(), - "In XLA:GPU run ptxas in -O0 (default is -O3)."), - tensorflow::Flag( - "xla_fuel", setter_for_xla_fuel, /*default_value_for_display=*/"", - "Sets compiler fuel, useful for bisecting bugs in passes. Format " - "--xla_fuel=PASS1=NUM1,PASS2=NUM2,..."), - - tensorflow::Flag( - "xla_dump_to", string_setter_for(&DebugOptions::set_xla_dump_to), - flag_values->xla_dump_to(), - "Directory into which debugging data is written. If not specified " - "but another dumping flag is passed, data will be written to stdout. " - " To explicitly write to stdout, set this to \"-\". The values " - "\"sponge\" and \"test_undeclared_outputs_dir\" have a special " - "meaning: They cause us to dump into the directory specified by the " - "environment variable TEST_UNDECLARED_OUTPUTS_DIR."), - tensorflow::Flag( - "xla_dump_hlo_as_text", - bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_text), - flag_values->xla_dump_hlo_as_text(), - "Dumps HLO modules as text before and after optimizations. 
Results " - "are written to the --xla_dump_to dir, or, if no dir is specified, " - "to stdout."), - tensorflow::Flag( - "xla_dump_hlo_as_proto", - bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_proto), - flag_values->xla_dump_hlo_as_proto(), - "Dumps HLO modules as HloProtos to the directory specified by " - "--xla_dump_to."), - tensorflow::Flag( - "xla_dump_hlo_as_dot", - bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_dot), - flag_values->xla_dump_hlo_as_dot(), - "Dumps HLO modules rendered as dot files to the directory " - "specified by --xla_dump_to."), + "Generate calls to MKL-DNN in the CPU backend.")); + flag_objects->push_back(tensorflow::Flag( + "xla_gpu_crash_on_verification_failures", + bool_setter_for( + &DebugOptions::set_xla_gpu_crash_on_verification_failures), + flag_values->xla_gpu_crash_on_verification_failures(), + "Crashes the program on extra verification failures, e.g. cuDNN cross " + "checking failures")); + flag_objects->push_back(tensorflow::Flag( + "xla_gpu_autotune_level", + int32_setter_for(&DebugOptions::set_xla_gpu_autotune_level), + flag_values->xla_gpu_autotune_level(), + "Set GEMM and Convolution auto-tuning level. 0 = off; 1 = on; 2 = " + "on+init; 3 = on+init+reinit; 4 = on+init+reinit+check.")); + flag_objects->push_back(tensorflow::Flag( + "xla_force_host_platform_device_count", + int32_setter_for(&DebugOptions::set_xla_force_host_platform_device_count), + flag_values->xla_force_host_platform_device_count(), + "Force the host platform to pretend that there are these many host " + "\"devices\". All of these host devices are backed by the same " + "threadpool. Setting this to anything other than 1 can increase overhead " + "from context switching but we let the user override this behavior to " + "help run tests on the host that run models in parallel across multiple " + "devices.")); + flag_objects->push_back(tensorflow::Flag( + "xla_gpu_disable_gpuasm_optimizations", + bool_setter_for(&DebugOptions::set_xla_gpu_disable_gpuasm_optimizations), + flag_values->xla_gpu_disable_gpuasm_optimizations(), + "In XLA:GPU run ptxas in -O0 (default is -O3).")); + flag_objects->push_back(tensorflow::Flag( + "xla_fuel", setter_for_xla_fuel, /*default_value_for_display=*/"", + "Sets compiler fuel, useful for bisecting bugs in passes. Format " + "--xla_fuel=PASS1=NUM1,PASS2=NUM2,...")); + flag_objects->push_back(tensorflow::Flag( + "xla_dump_to", string_setter_for(&DebugOptions::set_xla_dump_to), + flag_values->xla_dump_to(), + "Directory into which debugging data is written. If not specified but " + "another dumping flag is passed, data will be written to stdout. To " + "explicitly write to stdout, set this to \"-\". The values \"sponge\" " + "and \"test_undeclared_outputs_dir\" have a special meaning: They cause " + "us to dump into the directory specified by the environment variable " + "TEST_UNDECLARED_OUTPUTS_DIR.")); + flag_objects->push_back(tensorflow::Flag( + "xla_dump_hlo_as_text", + bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_text), + flag_values->xla_dump_hlo_as_text(), + "Dumps HLO modules as text before and after optimizations. 
Results are " + "written to the --xla_dump_to dir, or, if no dir is specified, to " + "stdout.")); + flag_objects->push_back(tensorflow::Flag( + "xla_dump_hlo_as_proto", + bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_proto), + flag_values->xla_dump_hlo_as_proto(), + "Dumps HLO modules as HloProtos to the directory specified by " + "--xla_dump_to.")); + flag_objects->push_back( + tensorflow::Flag("xla_dump_hlo_as_dot", + bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_dot), + flag_values->xla_dump_hlo_as_dot(), + "Dumps HLO modules rendered as dot files to the " + "directory specified by --xla_dump_to.")); + flag_objects->push_back( tensorflow::Flag("xla_dump_hlo_as_html", bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_html), flag_values->xla_dump_hlo_as_html(), "Dumps HLO modules rendered as HTML files to the " - "directory specified by --xla_dump_to."), - tensorflow::Flag( - "xla_dump_hlo_as_url", - bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_url), - flag_values->xla_dump_hlo_as_url(), - "Tries to dump HLO modules rendered as URLs to stdout (and also to " - "the directory specified by --xla_dump_to). This is not implemented " - "by default; you need to add a plugin which calls " - "RegisterGraphToURLRenderer()."), - tensorflow::Flag( - "xla_dump_hlo_snapshots", - bool_setter_for(&DebugOptions::set_xla_dump_hlo_snapshots), - flag_values->xla_dump_hlo_snapshots(), - "Every time an HLO module is run, dumps an HloSnapshot to the " - "directory specified by --xla_dump_to."), - tensorflow::Flag( - "xla_dump_hlo_module_re", - string_setter_for(&DebugOptions::set_xla_dump_hlo_module_re), - flag_values->xla_dump_hlo_module_re(), - "Limits dumping only to modules which match this regular expression. " - " Default is to dump all modules."), - tensorflow::Flag( - "xla_dump_hlo_pass_re", - string_setter_for(&DebugOptions::set_xla_dump_hlo_pass_re), - flag_values->xla_dump_hlo_pass_re(), - "If specified, dumps HLO before and after optimization passes which " - "match this regular expression, in addition to dumping at the very " - "beginning and end of compilation."), - tensorflow::Flag( - "xla_dump_include_timestamp", - bool_setter_for(&DebugOptions::set_xla_dump_include_timestamp), - flag_values->xla_dump_include_timestamp(), - "If specified, includes a timestamp in the dumped filenames."), - tensorflow::Flag( - "xla_dump_max_hlo_modules", - int32_setter_for(&DebugOptions::set_xla_dump_max_hlo_modules), - flag_values->xla_dump_max_hlo_modules(), - "Max number of hlo module dumps in a directory. 
Set to < 0 for " - "unbounded."), - tensorflow::Flag( - "xla_hlo_graph_addresses", - bool_setter_for(&DebugOptions::set_xla_hlo_graph_addresses), - flag_values->xla_hlo_graph_addresses(), - "When rendering graphs (--xla_dump_hlo_as_{dot,html,url}), displays " - "the address in memory of each HloInstruction object."), - tensorflow::Flag( - "xla_hlo_graph_sharding_color", - bool_setter_for(&DebugOptions::set_xla_hlo_graph_sharding_color), - flag_values->xla_hlo_graph_sharding_color(), - "Assign colors based on sharding assignments when generating the " - "HLO graphs."), - tensorflow::Flag( - "xla_allow_excess_precision", - bool_setter_for(&DebugOptions::set_xla_allow_excess_precision), - flag_values->xla_allow_excess_precision(), - "Allow xla to increase the output precision of an instruction."), - tensorflow::Flag( - "xla_gpu_force_conv_nchw", - bool_setter_for(&DebugOptions::set_xla_gpu_force_conv_nchw), - flag_values->xla_gpu_force_conv_nchw(), - "For cuDNN convolutions, always NCHW layouts."), - tensorflow::Flag("xla_gpu_algorithm_blacklist_path", - string_setter_for( - &DebugOptions::set_xla_gpu_algorithm_blacklist_path), - flag_values->xla_gpu_algorithm_blacklist_path(), - "An AlgorithmBlacklist text proto file as a blacklist " - "of convolutions to avoid to use."), - tensorflow::Flag( - "xla_gpu_deterministic_reductions", - bool_setter_for(&DebugOptions::set_xla_gpu_deterministic_reductions), - flag_values->xla_gpu_deterministic_reductions(), - "Always run deterministic reductions on GPU"), - tensorflow::Flag( - "xla_tpu_detect_nan", - bool_setter_for(&DebugOptions::set_xla_tpu_detect_nan), - flag_values->xla_tpu_detect_nan(), - "Trigger error on execution on TPU if a NAN value is detected"), - tensorflow::Flag( - "xla_tpu_detect_inf", - bool_setter_for(&DebugOptions::set_xla_tpu_detect_inf), - flag_values->xla_tpu_detect_inf(), - "Trigger error on execution on TPU if a INF value is detected"), - tensorflow::Flag( - "xla_cpu_enable_xprof_traceme", - bool_setter_for(&DebugOptions::set_xla_cpu_enable_xprof_traceme), - flag_values->xla_cpu_enable_xprof_traceme(), - "If true, XLA CPU generates code to call " - "TraceMe::Activity{Start|End} around HLO operations."), - }); + "directory specified by --xla_dump_to.")); + flag_objects->push_back(tensorflow::Flag( + "xla_dump_hlo_as_url", + bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_url), + flag_values->xla_dump_hlo_as_url(), + "Tries to dump HLO modules rendered as URLs to stdout (and also to the " + "directory specified by --xla_dump_to). This is not implemented by " + "default; you need to add a plugin which calls " + "RegisterGraphToURLRenderer().")); + flag_objects->push_back(tensorflow::Flag( + "xla_dump_hlo_snapshots", + bool_setter_for(&DebugOptions::set_xla_dump_hlo_snapshots), + flag_values->xla_dump_hlo_snapshots(), + "Every time an HLO module is run, dumps an HloSnapshot to the directory " + "specified by --xla_dump_to.")); + flag_objects->push_back(tensorflow::Flag( + "xla_dump_hlo_module_re", + string_setter_for(&DebugOptions::set_xla_dump_hlo_module_re), + flag_values->xla_dump_hlo_module_re(), + "Limits dumping only to modules which match this regular expression. 
" + "Default is to dump all modules.")); + flag_objects->push_back(tensorflow::Flag( + "xla_dump_hlo_pass_re", + string_setter_for(&DebugOptions::set_xla_dump_hlo_pass_re), + flag_values->xla_dump_hlo_pass_re(), + "If specified, dumps HLO before and after optimization passes which " + "match this regular expression, in addition to dumping at the very " + "beginning and end of compilation.")); + flag_objects->push_back(tensorflow::Flag( + "xla_dump_include_timestamp", + bool_setter_for(&DebugOptions::set_xla_dump_include_timestamp), + flag_values->xla_dump_include_timestamp(), + "If specified, includes a timestamp in the dumped filenames.")); + flag_objects->push_back(tensorflow::Flag( + "xla_dump_max_hlo_modules", + int32_setter_for(&DebugOptions::set_xla_dump_max_hlo_modules), + flag_values->xla_dump_max_hlo_modules(), + "Max number of hlo module dumps in a directory. Set to < 0 for " + "unbounded.")); + flag_objects->push_back(tensorflow::Flag( + "xla_hlo_graph_addresses", + bool_setter_for(&DebugOptions::set_xla_hlo_graph_addresses), + flag_values->xla_hlo_graph_addresses(), + "When rendering graphs (--xla_dump_hlo_as_{dot,html,url}), displays " + "the address in memory of each HloInstruction object.")); + flag_objects->push_back(tensorflow::Flag( + "xla_hlo_graph_sharding_color", + bool_setter_for(&DebugOptions::set_xla_hlo_graph_sharding_color), + flag_values->xla_hlo_graph_sharding_color(), + "Assign colors based on sharding assignments when generating the HLO " + "graphs.")); + flag_objects->push_back(tensorflow::Flag( + "xla_allow_excess_precision", + bool_setter_for(&DebugOptions::set_xla_allow_excess_precision), + flag_values->xla_allow_excess_precision(), + "Allow xla to increase the output precision of an instruction.")); + flag_objects->push_back(tensorflow::Flag( + "xla_gpu_force_conv_nchw", + bool_setter_for(&DebugOptions::set_xla_gpu_force_conv_nchw), + flag_values->xla_gpu_force_conv_nchw(), + "For cuDNN convolutions, always NCHW layouts.")); + flag_objects->push_back(tensorflow::Flag( + "xla_gpu_algorithm_blacklist_path", + string_setter_for(&DebugOptions::set_xla_gpu_algorithm_blacklist_path), + flag_values->xla_gpu_algorithm_blacklist_path(), + "An AlgorithmBlacklist text proto file as a blacklist of convolutions to " + "avoid to use.")); + flag_objects->push_back(tensorflow::Flag( + "xla_gpu_deterministic_reductions", + bool_setter_for(&DebugOptions::set_xla_gpu_deterministic_reductions), + flag_values->xla_gpu_deterministic_reductions(), + "Always run deterministic reductions on GPU")); + flag_objects->push_back(tensorflow::Flag( + "xla_tpu_detect_nan", + bool_setter_for(&DebugOptions::set_xla_tpu_detect_nan), + flag_values->xla_tpu_detect_nan(), + "Trigger error on execution on TPU if a NAN value is detected")); + flag_objects->push_back(tensorflow::Flag( + "xla_tpu_detect_inf", + bool_setter_for(&DebugOptions::set_xla_tpu_detect_inf), + flag_values->xla_tpu_detect_inf(), + "Trigger error on execution on TPU if a INF value is detected")); + flag_objects->push_back(tensorflow::Flag( + "xla_cpu_enable_xprof_traceme", + bool_setter_for(&DebugOptions::set_xla_cpu_enable_xprof_traceme), + flag_values->xla_cpu_enable_xprof_traceme(), + "If true, XLA CPU generates code to call " + "TraceMe::Activity{Start|End} around HLO operations.")); + flag_objects->push_back(tensorflow::Flag( + "xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found", + bool_setter_for( + &DebugOptions:: + set_xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found), + 
flag_values->xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found(), + "If true, XLA GPU falls back to the driver if ptxas is not found. Note " + "that falling back to the driver can have drawbacks like using more " + "memory and/or other bugs during compilation, so we recommend setting " + "this flag to false.")); + flag_objects->push_back(tensorflow::Flag( + "xla_gpu_unsafe_fallback_to_driver_on_ptxas_error", + bool_setter_for( + &DebugOptions::set_xla_gpu_unsafe_fallback_to_driver_on_ptxas_error), + flag_values->xla_gpu_unsafe_fallback_to_driver_on_ptxas_error(), + "If true, XLA GPU falls back to the driver if there is an error when " + "running ptxas. Note that falling back to the driver can have drawbacks " + "like using more memory and/or other bugs during compilation, so we " + "recommend setting this flag to false.")); ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", *flag_objects); } diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h index 43ee0fdd820..8ae8c418d5d 100644 --- a/tensorflow/compiler/xla/executable_run_options.h +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -50,6 +50,7 @@ class RunId { public: // Creates a new, unique RunId. RunId(); + explicit RunId(int64 value) : data_(value) {} RunId(const RunId&) = default; RunId& operator=(const RunId&) = default; diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py index b89bfd68073..212ad87d94c 100644 --- a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py +++ b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py @@ -243,3 +243,54 @@ def split(tensor, tensor, split_dimension, num_devices, input_shape).apply_to_tensor( tensor, assign_tuple_sharding=assign_tuple_sharding) return tensor + + +def get_op_sharding(op): + """Returns sharding attribute of an op. + + Args: + op: a TensorFlow op. + + Returns: + The attribute representing XLA sharding on this op. + """ + return op.get_attr('_XlaSharding') + + +def auto_to_manual_spmd_partition(tensor, manual_sharding): + """Switches from automatic SPMD partitioning to manual partitioning. + + Converts a full-shaped tensor (to be automatically partitioned by SPMD + partitioner) to a shard-shaped tensor to be consumed by manually partitioned + ops. + + Args: + tensor: A tf.Tensor in full shape. + manual_sharding: a serialized string of OpSharding to be used in manual + partitioning. + + Returns: + A shard-shaped tensor to be consumed by manually partitioned ops. + """ + return tf2xla.spmd_full_to_shard_shape( + tensor, manual_sharding=manual_sharding) + + +def manual_to_auto_spmd_partition(tensor, manual_sharding, full_shape): + """Switches from manual partitioning to automatic SPMD partitioning. + + Converts a shard-shaped tensor (manually partitioned in SPMD-style) to a + full-shaped tensor to be partitioned automatically by the SPMD partitioner. + + Args: + tensor: A tf.Tensor in shard shape. + manual_sharding: a serialized string of OpSharding to be used in manual + partitioning. + full_shape: the shape of tensor before partitioning. + + Returns: + A full-shaped tensor to be partitioned automatically by the SPMD + partitioner. 
+ """ + return tf2xla.spmd_shard_to_full_shape( + tensor, manual_sharding=manual_sharding, full_shape=full_shape) diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD new file mode 100644 index 00000000000..dbd33705d0e --- /dev/null +++ b/tensorflow/compiler/xla/pjrt/BUILD @@ -0,0 +1,213 @@ +load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") + +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "worker_thread", + srcs = ["worker_thread.cc"], + hdrs = ["worker_thread.h"], + deps = [ + "//tensorflow/core:lib", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "event_pool", + srcs = ["event_pool.cc"], + hdrs = ["event_pool.h"], + deps = [ + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "semaphore", + srcs = ["semaphore.cc"], + hdrs = ["semaphore.h"], + deps = [ + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/synchronization", + ], +) + +tf_cc_test( + name = "semaphore_test", + srcs = ["semaphore_test.cc"], + deps = [ + ":semaphore", + "//tensorflow/compiler/xla:test", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "tracked_device_buffer", + srcs = ["tracked_device_buffer.cc"], + hdrs = ["tracked_device_buffer.h"], + deps = [ + ":event_pool", + ":local_device_state", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/service:transfer_manager", + "//tensorflow/core:lib", + "//tensorflow/stream_executor:device_memory", + "//tensorflow/stream_executor:device_memory_allocator", + "//tensorflow/stream_executor:event", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/synchronization", + ], +) + +tf_cc_test( + name = "tracked_device_buffer_test", + srcs = ["tracked_device_buffer_test.cc"], + deps = [ + ":tracked_device_buffer", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/core:test_main", + "//tensorflow/stream_executor:device_memory", + "//tensorflow/stream_executor:device_memory_allocator", + ], +) + +cc_library( + name = "local_device_state", + srcs = ["local_device_state.cc"], + hdrs = ["local_device_state.h"], + deps = [ + ":event_pool", + ":semaphore", + ":worker_thread", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor", + "//tensorflow/stream_executor:event", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "pjrt_client", + srcs = ["pjrt_client.cc"], + hdrs = ["pjrt_client.h"], + visibility = ["//tensorflow/compiler/xla:friends"], + deps = [ + ":event_pool", + ":local_device_state", + ":tracked_device_buffer", + 
"//tensorflow/compiler/xla:cpu_function_runtime", + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/client:executable_build_options", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/pjrt/distributed:protocol_proto_cc", + "//tensorflow/compiler/xla/service:computation_placer", + "//tensorflow/compiler/xla/service:executable", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:maybe_owning_device_memory", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options", + "//tensorflow/core:allocator", + "//tensorflow/core:lib", + "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/stream_executor:event", + "//tensorflow/stream_executor:stream", + "//tensorflow/stream_executor/host:host_platform_id", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "cpu_device", + srcs = ["cpu_device.cc"], + hdrs = ["cpu_device.h"], + deps = [ + ":pjrt_client", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/service:platform_util", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "nvidia_gpu_device", + srcs = ["nvidia_gpu_device.cc"], + hdrs = ["nvidia_gpu_device.h"], + copts = if_cuda(["-DNCCL_ENABLED=1"]), + deps = [ + ":pjrt_client", + "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/pjrt/distributed:client", + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/core/common_runtime:bfc_allocator", + "//tensorflow/core/common_runtime/gpu:gpu_mem_allocator", + "//tensorflow/stream_executor:tf_allocator_adapter", + ] + if_cuda(["@local_config_nccl//:nccl"]), +) + +tf_cc_test( + name = "gpu_multistream_test", + srcs = ["gpu_multistream_test.cc"], + tags = [ + # TODO(phawkins): figure out TF test infra such that this only runs under GPU. 
+ "no_oss", + "requires-gpu-nvidia", + ], + deps = [ + ":nvidia_gpu_device", + ":pjrt_client", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:executable_build_options", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + "//tensorflow/core/platform:random", + ], +) diff --git a/tensorflow/compiler/xla/python/cpu_device.cc b/tensorflow/compiler/xla/pjrt/cpu_device.cc similarity index 82% rename from tensorflow/compiler/xla/python/cpu_device.cc rename to tensorflow/compiler/xla/pjrt/cpu_device.cc index 12e1e55723b..75c3bfc1277 100644 --- a/tensorflow/compiler/xla/python/cpu_device.cc +++ b/tensorflow/compiler/xla/pjrt/cpu_device.cc @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/cpu_device.h" +#include "tensorflow/compiler/xla/pjrt/cpu_device.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/service/platform_util.h" @@ -40,8 +41,14 @@ StatusOr> GetCpuClient(bool asynchronous) { std::vector> devices; for (int i = 0; i < client->device_count(); ++i) { - se::StreamExecutor* executor = - client->backend().stream_executor(i).ValueOrDie(); + se::StreamExecutorConfig config; + config.ordinal = i; + // 8MiB stacks seem to be necessary for running LAPACK/OpenBLAS + // computations. + config.device_options.non_portable_tags["host_thread_stack_size_in_bytes"] = + absl::StrCat(8192 * 1024); + TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, + platform->GetExecutor(config)); auto device_state = absl::make_unique( executor, client, LocalDeviceState::kSynchronous, asynchronous, /*allow_event_reuse=*/false); diff --git a/tensorflow/compiler/xla/python/cpu_device.h b/tensorflow/compiler/xla/pjrt/cpu_device.h similarity index 81% rename from tensorflow/compiler/xla/python/cpu_device.h rename to tensorflow/compiler/xla/pjrt/cpu_device.h index 38e81644b1e..c70d90ae228 100644 --- a/tensorflow/compiler/xla/python/cpu_device.h +++ b/tensorflow/compiler/xla/pjrt/cpu_device.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_CPU_DEVICE_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_CPU_DEVICE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_CPU_DEVICE_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_CPU_DEVICE_H_ #include -#include "tensorflow/compiler/xla/python/local_client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/statusor.h" namespace xla { @@ -32,4 +32,4 @@ StatusOr> GetCpuClient(bool asynchronous); } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_CPU_DEVICE_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_CPU_DEVICE_H_ diff --git a/tensorflow/compiler/xla/python/distributed/BUILD b/tensorflow/compiler/xla/pjrt/distributed/BUILD similarity index 100% rename from tensorflow/compiler/xla/python/distributed/BUILD rename to tensorflow/compiler/xla/pjrt/distributed/BUILD diff --git a/tensorflow/compiler/xla/python/distributed/client.cc b/tensorflow/compiler/xla/pjrt/distributed/client.cc similarity index 94% rename from tensorflow/compiler/xla/python/distributed/client.cc rename to tensorflow/compiler/xla/pjrt/distributed/client.cc index c50c3f50a9d..830e512b156 100644 --- a/tensorflow/compiler/xla/python/distributed/client.cc +++ b/tensorflow/compiler/xla/pjrt/distributed/client.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/distributed/client.h" +#include "tensorflow/compiler/xla/pjrt/distributed/client.h" #include // NOLINT -#include "tensorflow/compiler/xla/python/distributed/protocol.h" -#include "tensorflow/compiler/xla/python/distributed/util.h" +#include "tensorflow/compiler/xla/pjrt/distributed/protocol.h" +#include "tensorflow/compiler/xla/pjrt/distributed/util.h" namespace xla { diff --git a/tensorflow/compiler/xla/python/distributed/client.h b/tensorflow/compiler/xla/pjrt/distributed/client.h similarity index 85% rename from tensorflow/compiler/xla/python/distributed/client.h rename to tensorflow/compiler/xla/pjrt/distributed/client.h index 1ab5292bea8..865a752849e 100644 --- a/tensorflow/compiler/xla/python/distributed/client.h +++ b/tensorflow/compiler/xla/pjrt/distributed/client.h @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_CLIENT_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_CLIENT_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_CLIENT_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_CLIENT_H_ #include #include "grpcpp/channel.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" -#include "tensorflow/compiler/xla/python/distributed/protocol.grpc.pb.h" +#include "tensorflow/compiler/xla/pjrt/distributed/protocol.grpc.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/platform/env.h" @@ -47,4 +47,4 @@ class DistributedRuntimeClient { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_CLIENT_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_CLIENT_H_ diff --git a/tensorflow/compiler/xla/python/distributed/client_server_test.cc b/tensorflow/compiler/xla/pjrt/distributed/client_server_test.cc similarity index 95% rename from tensorflow/compiler/xla/python/distributed/client_server_test.cc rename to tensorflow/compiler/xla/pjrt/distributed/client_server_test.cc index e78949933a2..cfe60a06207 100644 --- a/tensorflow/compiler/xla/python/distributed/client_server_test.cc +++ b/tensorflow/compiler/xla/pjrt/distributed/client_server_test.cc @@ -15,10 +15,10 @@ limitations under the License. #include "grpcpp/security/server_credentials.h" #include "absl/time/time.h" +#include "tensorflow/compiler/xla/pjrt/distributed/client.h" +#include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h" +#include "tensorflow/compiler/xla/pjrt/distributed/service.h" #include "tensorflow/compiler/xla/protobuf_util.h" -#include "tensorflow/compiler/xla/python/distributed/client.h" -#include "tensorflow/compiler/xla/python/distributed/protocol.pb.h" -#include "tensorflow/compiler/xla/python/distributed/service.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/xla/python/distributed/distributed.cc b/tensorflow/compiler/xla/pjrt/distributed/distributed.cc similarity index 95% rename from tensorflow/compiler/xla/python/distributed/distributed.cc rename to tensorflow/compiler/xla/pjrt/distributed/distributed.cc index 6afc7b1c4e9..7753e2dcfc7 100644 --- a/tensorflow/compiler/xla/python/distributed/distributed.cc +++ b/tensorflow/compiler/xla/pjrt/distributed/distributed.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/distributed/distributed.h" +#include "tensorflow/compiler/xla/pjrt/distributed/distributed.h" #include "grpcpp/grpcpp.h" diff --git a/tensorflow/compiler/xla/python/distributed/distributed.h b/tensorflow/compiler/xla/pjrt/distributed/distributed.h similarity index 83% rename from tensorflow/compiler/xla/python/distributed/distributed.h rename to tensorflow/compiler/xla/pjrt/distributed/distributed.h index 0475c3e9feb..b3909387259 100644 --- a/tensorflow/compiler/xla/python/distributed/distributed.h +++ b/tensorflow/compiler/xla/pjrt/distributed/distributed.h @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_DISTRIBUTED_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_DISTRIBUTED_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_DISTRIBUTED_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_DISTRIBUTED_H_ #include #include -#include "tensorflow/compiler/xla/python/distributed/client.h" -#include "tensorflow/compiler/xla/python/distributed/service.h" +#include "tensorflow/compiler/xla/pjrt/distributed/client.h" +#include "tensorflow/compiler/xla/pjrt/distributed/service.h" #include "tensorflow/compiler/xla/statusor.h" namespace xla { @@ -43,4 +43,4 @@ std::shared_ptr GetDistributedRuntimeClient( } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_DISTRIBUTED_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_DISTRIBUTED_H_ diff --git a/tensorflow/compiler/xla/python/distributed/key_value_store.cc b/tensorflow/compiler/xla/pjrt/distributed/key_value_store.cc similarity index 95% rename from tensorflow/compiler/xla/python/distributed/key_value_store.cc rename to tensorflow/compiler/xla/pjrt/distributed/key_value_store.cc index 5966d4ce12b..e989b1384d2 100644 --- a/tensorflow/compiler/xla/python/distributed/key_value_store.cc +++ b/tensorflow/compiler/xla/pjrt/distributed/key_value_store.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/distributed/key_value_store.h" +#include "tensorflow/compiler/xla/pjrt/distributed/key_value_store.h" namespace xla { diff --git a/tensorflow/compiler/xla/python/distributed/key_value_store.h b/tensorflow/compiler/xla/pjrt/distributed/key_value_store.h similarity index 89% rename from tensorflow/compiler/xla/python/distributed/key_value_store.h rename to tensorflow/compiler/xla/pjrt/distributed/key_value_store.h index 8560305e6f6..d496de1feb5 100644 --- a/tensorflow/compiler/xla/python/distributed/key_value_store.h +++ b/tensorflow/compiler/xla/pjrt/distributed/key_value_store.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_KEY_VALUE_STORE_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_KEY_VALUE_STORE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_KEY_VALUE_STORE_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_KEY_VALUE_STORE_H_ #include "grpcpp/grpcpp.h" #include "absl/base/thread_annotations.h" @@ -50,4 +50,4 @@ class KeyValueStore { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_KEY_VALUE_STORE_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_KEY_VALUE_STORE_H_ diff --git a/tensorflow/compiler/xla/python/distributed/protocol.h b/tensorflow/compiler/xla/pjrt/distributed/protocol.h similarity index 80% rename from tensorflow/compiler/xla/python/distributed/protocol.h rename to tensorflow/compiler/xla/pjrt/distributed/protocol.h index 208c6dab8c5..4daa939ac8d 100644 --- a/tensorflow/compiler/xla/python/distributed/protocol.h +++ b/tensorflow/compiler/xla/pjrt/distributed/protocol.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_PROTOCOL_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_PROTOCOL_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_PROTOCOL_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_PROTOCOL_H_ namespace xla { @@ -22,4 +22,4 @@ static constexpr int kDistributedRuntimeProtocolVersion = 1; } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_PROTOCOL_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_PROTOCOL_H_ diff --git a/tensorflow/compiler/xla/python/distributed/protocol.proto b/tensorflow/compiler/xla/pjrt/distributed/protocol.proto similarity index 100% rename from tensorflow/compiler/xla/python/distributed/protocol.proto rename to tensorflow/compiler/xla/pjrt/distributed/protocol.proto diff --git a/tensorflow/compiler/xla/python/distributed/service.cc b/tensorflow/compiler/xla/pjrt/distributed/service.cc similarity index 96% rename from tensorflow/compiler/xla/python/distributed/service.cc rename to tensorflow/compiler/xla/pjrt/distributed/service.cc index cc2b3a5aca2..3325fcd8319 100644 --- a/tensorflow/compiler/xla/python/distributed/service.cc +++ b/tensorflow/compiler/xla/pjrt/distributed/service.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/distributed/service.h" +#include "tensorflow/compiler/xla/pjrt/distributed/service.h" -#include "tensorflow/compiler/xla/python/distributed/protocol.h" -#include "tensorflow/compiler/xla/python/distributed/util.h" +#include "tensorflow/compiler/xla/pjrt/distributed/protocol.h" +#include "tensorflow/compiler/xla/pjrt/distributed/util.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/util.h" diff --git a/tensorflow/compiler/xla/python/distributed/service.h b/tensorflow/compiler/xla/pjrt/distributed/service.h similarity index 91% rename from tensorflow/compiler/xla/python/distributed/service.h rename to tensorflow/compiler/xla/pjrt/distributed/service.h index baf470e4f13..725a76791ce 100644 --- a/tensorflow/compiler/xla/python/distributed/service.h +++ b/tensorflow/compiler/xla/pjrt/distributed/service.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_SERVICE_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_SERVICE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_SERVICE_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_SERVICE_H_ #include "absl/time/time.h" -#include "tensorflow/compiler/xla/python/distributed/key_value_store.h" -#include "tensorflow/compiler/xla/python/distributed/protocol.grpc.pb.h" +#include "tensorflow/compiler/xla/pjrt/distributed/key_value_store.h" +#include "tensorflow/compiler/xla/pjrt/distributed/protocol.grpc.pb.h" #include "tensorflow/compiler/xla/statusor.h" namespace xla { @@ -98,4 +98,4 @@ void BuildGlobalTopology(absl::Span local_topologies, } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_SERVICE_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_SERVICE_H_ diff --git a/tensorflow/compiler/xla/python/distributed/service_test.cc b/tensorflow/compiler/xla/pjrt/distributed/service_test.cc similarity index 91% rename from tensorflow/compiler/xla/python/distributed/service_test.cc rename to tensorflow/compiler/xla/pjrt/distributed/service_test.cc index 08326df2f38..b56dbb17d1a 100644 --- a/tensorflow/compiler/xla/python/distributed/service_test.cc +++ b/tensorflow/compiler/xla/pjrt/distributed/service_test.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/distributed/service.h" +#include "tensorflow/compiler/xla/pjrt/distributed/service.h" -#include "tensorflow/compiler/xla/python/distributed/protocol.pb.h" +#include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/xla/python/distributed/util.h b/tensorflow/compiler/xla/pjrt/distributed/util.h similarity index 87% rename from tensorflow/compiler/xla/python/distributed/util.h rename to tensorflow/compiler/xla/pjrt/distributed/util.h index 07ae8d1f0ce..abb2b6089e7 100644 --- a/tensorflow/compiler/xla/python/distributed/util.h +++ b/tensorflow/compiler/xla/pjrt/distributed/util.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_UTIL_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_UTIL_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_UTIL_H_ #include "grpcpp/support/status.h" #include "tensorflow/compiler/xla/status.h" @@ -41,4 +41,4 @@ inline ::grpc::Status ToGrpcStatus(const Status& s) { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DISTRIBUTED_UTIL_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_DISTRIBUTED_UTIL_H_ diff --git a/tensorflow/compiler/xla/python/event_pool.cc b/tensorflow/compiler/xla/pjrt/event_pool.cc similarity index 96% rename from tensorflow/compiler/xla/python/event_pool.cc rename to tensorflow/compiler/xla/pjrt/event_pool.cc index c7b52f523d9..86aa38cdd0f 100644 --- a/tensorflow/compiler/xla/python/event_pool.cc +++ b/tensorflow/compiler/xla/pjrt/event_pool.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/event_pool.h" +#include "tensorflow/compiler/xla/pjrt/event_pool.h" #include "absl/memory/memory.h" #include "absl/synchronization/mutex.h" diff --git a/tensorflow/compiler/xla/python/event_pool.h b/tensorflow/compiler/xla/pjrt/event_pool.h similarity index 95% rename from tensorflow/compiler/xla/python/event_pool.h rename to tensorflow/compiler/xla/pjrt/event_pool.h index bda3fb6baff..47768c28fd9 100644 --- a/tensorflow/compiler/xla/python/event_pool.h +++ b/tensorflow/compiler/xla/pjrt/event_pool.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_EVENT_POOL_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_EVENT_POOL_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_EVENT_POOL_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_EVENT_POOL_H_ #include #include @@ -87,4 +87,4 @@ class EventPool { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_EVENT_POOL_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_EVENT_POOL_H_ diff --git a/tensorflow/compiler/xla/python/gpu_multistream_test.cc b/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc similarity index 97% rename from tensorflow/compiler/xla/python/gpu_multistream_test.cc rename to tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc index bc6ecb14ae2..2db7de3720d 100644 --- a/tensorflow/compiler/xla/python/gpu_multistream_test.cc +++ b/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc @@ -15,8 +15,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/python/local_client.h" -#include "tensorflow/compiler/xla/python/nvidia_gpu_device.h" +#include "tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/platform/random.h" diff --git a/tensorflow/compiler/xla/python/local_device_state.cc b/tensorflow/compiler/xla/pjrt/local_device_state.cc similarity index 98% rename from tensorflow/compiler/xla/python/local_device_state.cc rename to tensorflow/compiler/xla/pjrt/local_device_state.cc index 6a96908cb12..d173c891c95 100644 --- a/tensorflow/compiler/xla/python/local_device_state.cc +++ b/tensorflow/compiler/xla/pjrt/local_device_state.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/local_device_state.h" +#include "tensorflow/compiler/xla/pjrt/local_device_state.h" #include #include diff --git a/tensorflow/compiler/xla/python/local_device_state.h b/tensorflow/compiler/xla/pjrt/local_device_state.h similarity index 96% rename from tensorflow/compiler/xla/python/local_device_state.h rename to tensorflow/compiler/xla/pjrt/local_device_state.h index 5cd2c0014a0..eb25c37878f 100644 --- a/tensorflow/compiler/xla/python/local_device_state.h +++ b/tensorflow/compiler/xla/pjrt/local_device_state.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_DEVICE_STATE_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_DEVICE_STATE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_LOCAL_DEVICE_STATE_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_LOCAL_DEVICE_STATE_H_ #include #include @@ -22,9 +22,9 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/python/event_pool.h" -#include "tensorflow/compiler/xla/python/semaphore.h" -#include "tensorflow/compiler/xla/python/worker_thread.h" +#include "tensorflow/compiler/xla/pjrt/event_pool.h" +#include "tensorflow/compiler/xla/pjrt/semaphore.h" +#include "tensorflow/compiler/xla/pjrt/worker_thread.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/core/platform/stream_executor.h" @@ -207,4 +207,4 @@ class LocalDeviceState { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_DEVICE_STATE_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_LOCAL_DEVICE_STATE_H_ diff --git a/tensorflow/compiler/xla/python/nvidia_gpu_device.cc b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc similarity index 99% rename from tensorflow/compiler/xla/python/nvidia_gpu_device.cc rename to tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc index 886ed697f4e..4863e5e8165 100644 --- a/tensorflow/compiler/xla/python/nvidia_gpu_device.cc +++ b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/xla/python/nvidia_gpu_device.h" +#include "tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h" #ifdef NCCL_ENABLED #include "third_party/nccl/nccl.h" diff --git a/tensorflow/compiler/xla/python/nvidia_gpu_device.h b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h similarity index 87% rename from tensorflow/compiler/xla/python/nvidia_gpu_device.h rename to tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h index 2f9922454fa..bf59ddef3a9 100644 --- a/tensorflow/compiler/xla/python/nvidia_gpu_device.h +++ b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_NVIDIA_GPU_DEVICE_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_NVIDIA_GPU_DEVICE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_NVIDIA_GPU_DEVICE_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_NVIDIA_GPU_DEVICE_H_ #include -#include "tensorflow/compiler/xla/python/distributed/client.h" -#include "tensorflow/compiler/xla/python/local_client.h" +#include "tensorflow/compiler/xla/pjrt/distributed/client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/bfc_allocator.h" @@ -59,4 +59,4 @@ StatusOr> GetNvidiaGpuClient( } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_NVIDIA_GPU_DEVICE_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_NVIDIA_GPU_DEVICE_H_ diff --git a/tensorflow/compiler/xla/python/local_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_client.cc similarity index 99% rename from tensorflow/compiler/xla/python/local_client.cc rename to tensorflow/compiler/xla/pjrt/pjrt_client.cc index f2acd0d6398..80fd0e0b658 100644 --- a/tensorflow/compiler/xla/python/local_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.cc @@ -62,7 +62,7 @@ limitations under the License. // See the comment on LocalDeviceState::AllocationModel for a discussion of the // different allocation semantics on CPU, GPU, and TPU. -#include "tensorflow/compiler/xla/python/local_client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include #include @@ -83,10 +83,10 @@ limitations under the License. 
#include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/python/distributed/protocol.pb.h" -#include "tensorflow/compiler/xla/python/event_pool.h" -#include "tensorflow/compiler/xla/python/local_device_state.h" -#include "tensorflow/compiler/xla/python/tracked_device_buffer.h" +#include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h" +#include "tensorflow/compiler/xla/pjrt/event_pool.h" +#include "tensorflow/compiler/xla/pjrt/local_device_state.h" +#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h" diff --git a/tensorflow/compiler/xla/python/local_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h similarity index 99% rename from tensorflow/compiler/xla/python/local_client.h rename to tensorflow/compiler/xla/pjrt/pjrt_client.h index f09e70037d6..775b44c7073 100644 --- a/tensorflow/compiler/xla/python/local_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_CLIENT_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_CLIENT_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_ #include #include @@ -29,8 +29,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/python/local_device_state.h" -#include "tensorflow/compiler/xla/python/tracked_device_buffer.h" +#include "tensorflow/compiler/xla/pjrt/local_device_state.h" +#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" @@ -681,4 +681,4 @@ class PjRtExecutable { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_CLIENT_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_ diff --git a/tensorflow/compiler/xla/python/semaphore.cc b/tensorflow/compiler/xla/pjrt/semaphore.cc similarity index 97% rename from tensorflow/compiler/xla/python/semaphore.cc rename to tensorflow/compiler/xla/pjrt/semaphore.cc index 5926618bddc..c1df52acc61 100644 --- a/tensorflow/compiler/xla/python/semaphore.cc +++ b/tensorflow/compiler/xla/pjrt/semaphore.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/xla/python/semaphore.h" +#include "tensorflow/compiler/xla/pjrt/semaphore.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/python/semaphore.h b/tensorflow/compiler/xla/pjrt/semaphore.h similarity index 92% rename from tensorflow/compiler/xla/python/semaphore.h rename to tensorflow/compiler/xla/pjrt/semaphore.h index 7d3e9ce6271..45345becf74 100644 --- a/tensorflow/compiler/xla/python/semaphore.h +++ b/tensorflow/compiler/xla/pjrt/semaphore.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_SEMAPHORE_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_SEMAPHORE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_SEMAPHORE_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_SEMAPHORE_H_ #include "absl/synchronization/mutex.h" #include "tensorflow/compiler/xla/types.h" @@ -65,4 +65,4 @@ class Semaphore { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_SEMAPHORE_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_SEMAPHORE_H_ diff --git a/tensorflow/compiler/xla/python/semaphore_test.cc b/tensorflow/compiler/xla/pjrt/semaphore_test.cc similarity index 97% rename from tensorflow/compiler/xla/python/semaphore_test.cc rename to tensorflow/compiler/xla/pjrt/semaphore_test.cc index 5ef59618b8b..56f7e8c9a05 100644 --- a/tensorflow/compiler/xla/python/semaphore_test.cc +++ b/tensorflow/compiler/xla/pjrt/semaphore_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/semaphore.h" +#include "tensorflow/compiler/xla/pjrt/semaphore.h" #include "absl/synchronization/notification.h" #include "tensorflow/compiler/xla/test.h" diff --git a/tensorflow/compiler/xla/python/tracked_device_buffer.cc b/tensorflow/compiler/xla/pjrt/tracked_device_buffer.cc similarity index 98% rename from tensorflow/compiler/xla/python/tracked_device_buffer.cc rename to tensorflow/compiler/xla/pjrt/tracked_device_buffer.cc index 5c6dbbf3289..32ca4e4550c 100644 --- a/tensorflow/compiler/xla/python/tracked_device_buffer.cc +++ b/tensorflow/compiler/xla/pjrt/tracked_device_buffer.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/xla/python/tracked_device_buffer.h" +#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" #include #include #include "absl/synchronization/mutex.h" -#include "tensorflow/compiler/xla/python/local_device_state.h" +#include "tensorflow/compiler/xla/pjrt/local_device_state.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/stream_executor/device_memory.h" diff --git a/tensorflow/compiler/xla/python/tracked_device_buffer.h b/tensorflow/compiler/xla/pjrt/tracked_device_buffer.h similarity index 97% rename from tensorflow/compiler/xla/python/tracked_device_buffer.h rename to tensorflow/compiler/xla/pjrt/tracked_device_buffer.h index 27e7de6e2c2..562cb2f913e 100644 --- a/tensorflow/compiler/xla/python/tracked_device_buffer.h +++ b/tensorflow/compiler/xla/pjrt/tracked_device_buffer.h @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TRACKED_DEVICE_BUFFER_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_TRACKED_DEVICE_BUFFER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_DEVICE_BUFFER_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_DEVICE_BUFFER_H_ #include #include "absl/container/flat_hash_set.h" -#include "tensorflow/compiler/xla/python/event_pool.h" -#include "tensorflow/compiler/xla/python/local_device_state.h" +#include "tensorflow/compiler/xla/pjrt/event_pool.h" +#include "tensorflow/compiler/xla/pjrt/local_device_state.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape.h" @@ -257,4 +257,4 @@ void WaitForBufferDefinitionEventsOnStream(const TrackedDeviceBuffer& buffer, } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_TRACKED_DEVICE_BUFFER_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_DEVICE_BUFFER_H_ diff --git a/tensorflow/compiler/xla/python/tracked_device_buffer_test.cc b/tensorflow/compiler/xla/pjrt/tracked_device_buffer_test.cc similarity index 98% rename from tensorflow/compiler/xla/python/tracked_device_buffer_test.cc rename to tensorflow/compiler/xla/pjrt/tracked_device_buffer_test.cc index 354176654af..9373b57e7d1 100644 --- a/tensorflow/compiler/xla/python/tracked_device_buffer_test.cc +++ b/tensorflow/compiler/xla/pjrt/tracked_device_buffer_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/tracked_device_buffer.h" +#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" #include diff --git a/tensorflow/compiler/xla/python/worker_thread.cc b/tensorflow/compiler/xla/pjrt/worker_thread.cc similarity index 96% rename from tensorflow/compiler/xla/python/worker_thread.cc rename to tensorflow/compiler/xla/pjrt/worker_thread.cc index d3fb02023a5..e8194534aef 100644 --- a/tensorflow/compiler/xla/python/worker_thread.cc +++ b/tensorflow/compiler/xla/pjrt/worker_thread.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/xla/python/worker_thread.h" +#include "tensorflow/compiler/xla/pjrt/worker_thread.h" namespace xla { diff --git a/tensorflow/compiler/xla/python/worker_thread.h b/tensorflow/compiler/xla/pjrt/worker_thread.h similarity index 90% rename from tensorflow/compiler/xla/python/worker_thread.h rename to tensorflow/compiler/xla/pjrt/worker_thread.h index 598f7b1d4ae..4fd2baa4cda 100644 --- a/tensorflow/compiler/xla/python/worker_thread.h +++ b/tensorflow/compiler/xla/pjrt/worker_thread.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_WORKER_THREAD_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_WORKER_THREAD_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_WORKER_THREAD_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_WORKER_THREAD_H_ #include #include @@ -51,4 +51,4 @@ class WorkerThread { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_WORKER_THREAD_H_ +#endif // TENSORFLOW_COMPILER_XLA_PJRT_WORKER_THREAD_H_ diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 3eb93f9559e..8c6bc84cf8e 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -1,7 +1,5 @@ load("//tensorflow/core/platform:build_config.bzl", "pyx_library") load("//tensorflow/compiler/xla:xla.bzl", "xla_py_test_deps") -load("//tensorflow:tensorflow.bzl", "py_test", "tf_cc_test") -load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "pybind_extension") @@ -78,16 +76,6 @@ py_test( ] + xla_py_test_deps(), ) -cc_library( - name = "worker_thread", - srcs = ["worker_thread.cc"], - hdrs = ["worker_thread.h"], - deps = [ - "//tensorflow/core:lib", - "@com_google_absl//absl/synchronization", - ], -) - cc_library( name = "types", srcs = ["types.cc"], @@ -99,7 +87,6 @@ cc_library( features = ["-use_header_modules"], deps = [ ":bfloat16", - ":local_client", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", @@ -107,6 +94,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/pjrt:pjrt_client", "//tensorflow/core:lib", "//third_party/py/numpy:headers", "@com_google_absl//absl/container:flat_hash_map", @@ -116,148 +104,6 @@ cc_library( ], ) -cc_library( - name = "event_pool", - srcs = ["event_pool.cc"], - hdrs = ["event_pool.h"], - deps = [ - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:lib", - "//tensorflow/core:stream_executor", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/synchronization", - ], -) - -cc_library( - name = "semaphore", - srcs = ["semaphore.cc"], - hdrs = ["semaphore.h"], - deps = [ - "//tensorflow/compiler/xla:types", - "//tensorflow/core:lib", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/synchronization", - ], -) - -tf_cc_test( - name = "semaphore_test", - srcs = ["semaphore_test.cc"], - deps = [ - ":semaphore", - "//tensorflow/compiler/xla:test", - "//tensorflow/core:lib", - "//tensorflow/core:test_main", - "@com_google_absl//absl/synchronization", - ], -) - -cc_library( - 
name = "tracked_device_buffer", - srcs = ["tracked_device_buffer.cc"], - hdrs = ["tracked_device_buffer.h"], - deps = [ - ":event_pool", - ":local_device_state", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla/service:shaped_buffer", - "//tensorflow/compiler/xla/service:transfer_manager", - "//tensorflow/core:lib", - "//tensorflow/stream_executor:device_memory", - "//tensorflow/stream_executor:device_memory_allocator", - "//tensorflow/stream_executor:event", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/synchronization", - ], -) - -tf_cc_test( - name = "tracked_device_buffer_test", - srcs = ["tracked_device_buffer_test.cc"], - deps = [ - ":tracked_device_buffer", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/service:cpu_plugin", - "//tensorflow/core:test_main", - "//tensorflow/stream_executor:device_memory", - "//tensorflow/stream_executor:device_memory_allocator", - ], -) - -cc_library( - name = "local_device_state", - srcs = ["local_device_state.cc"], - hdrs = ["local_device_state.h"], - deps = [ - ":event_pool", - ":semaphore", - ":worker_thread", - "//tensorflow/compiler/xla:status", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/core:lib", - "//tensorflow/core:stream_executor", - "//tensorflow/stream_executor:event", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/synchronization", - ], -) - -cc_library( - name = "local_client", - srcs = ["local_client.cc"], - hdrs = ["local_client.h"], - visibility = ["//tensorflow/compiler/xla:friends"], - deps = [ - ":event_pool", - ":local_device_state", - ":tracked_device_buffer", - "//tensorflow/compiler/xla:cpu_function_runtime", - "//tensorflow/compiler/xla:executable_run_options", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto_cc", - "//tensorflow/compiler/xla/client:executable_build_options", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/python/distributed:protocol_proto_cc", - "//tensorflow/compiler/xla/service:computation_placer", - "//tensorflow/compiler/xla/service:executable", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:maybe_owning_device_memory", - "//tensorflow/compiler/xla/service:shaped_buffer", - "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options", - "//tensorflow/core:allocator", - "//tensorflow/core:lib", - "//tensorflow/core/profiler/lib:traceme", - "//tensorflow/stream_executor:event", - "//tensorflow/stream_executor:stream", - "//tensorflow/stream_executor/host:host_platform_id", - "//tensorflow/stream_executor/lib", - "@com_google_absl//absl/base", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:span", - ], -) - 
cc_library( name = "python_ref_manager", srcs = ["python_ref_manager.cc"], @@ -322,10 +168,10 @@ cc_library( ], features = ["-use_header_modules"], deps = [ - ":local_client", - ":tracked_device_buffer", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/pjrt:pjrt_client", + "//tensorflow/compiler/xla/pjrt:tracked_device_buffer", "//tensorflow/stream_executor:device_memory", "//tensorflow/stream_executor:platform", "//tensorflow/stream_executor/cuda:cuda_platform_id", @@ -340,37 +186,6 @@ cc_library( ], ) -cc_library( - name = "cpu_device", - srcs = ["cpu_device.cc"], - hdrs = ["cpu_device.h"], - deps = [ - ":local_client", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/service:platform_util", - ], -) - -cc_library( - name = "nvidia_gpu_device", - srcs = ["nvidia_gpu_device.cc"], - hdrs = ["nvidia_gpu_device.h"], - copts = if_cuda(["-DNCCL_ENABLED=1"]), - deps = [ - ":local_client", - "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/python/distributed:client", - "//tensorflow/compiler/xla/service:platform_util", - "//tensorflow/compiler/xla:util", - "//tensorflow/core/common_runtime:bfc_allocator", - "//tensorflow/core/common_runtime/gpu:gpu_mem_allocator", - "//tensorflow/stream_executor:tf_allocator_adapter", - ] + if_cuda(["@local_config_nccl//:nccl"]), -) - config_setting( name = "enable_gpu", values = {"define": "xla_python_enable_gpu=true"}, @@ -389,11 +204,7 @@ pybind_extension( module_name = "xla_extension", deps = [ ":bfloat16", - ":cpu_device", ":dlpack", - ":local_client", - ":nvidia_gpu_device", - ":tracked_device_buffer", ":python_ref_manager", ":types", "@com_google_absl//absl/base", @@ -423,9 +234,13 @@ pybind_extension( "//tensorflow/compiler/xla/client/lib:self_adjoint_eig", "//tensorflow/compiler/xla/client/lib:sorting", "//tensorflow/compiler/xla/client/lib:svd", - "//tensorflow/compiler/xla/python/distributed", - "//tensorflow/compiler/xla/python/distributed:client", - "//tensorflow/compiler/xla/python/distributed:service", + "//tensorflow/compiler/xla/pjrt:cpu_device", + "//tensorflow/compiler/xla/pjrt:nvidia_gpu_device", + "//tensorflow/compiler/xla/pjrt:pjrt_client", + "//tensorflow/compiler/xla/pjrt:tracked_device_buffer", + "//tensorflow/compiler/xla/pjrt/distributed", + "//tensorflow/compiler/xla/pjrt/distributed:client", + "//tensorflow/compiler/xla/pjrt/distributed:service", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:custom_call_target_registry", "//tensorflow/compiler/xla/service:hlo", @@ -454,25 +269,3 @@ pybind_extension( "//conditions:default": [], }), ) - -tf_cc_test( - name = "gpu_multistream_test", - srcs = ["gpu_multistream_test.cc"], - tags = [ - # TODO(phawkins): figure out TF test infra such that this only runs under GPU. 
- "no_oss", - "requires-gpu-nvidia", - ], - deps = [ - ":local_client", - ":nvidia_gpu_device", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/client:executable_build_options", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/service:gpu_plugin", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/core:lib", - "//tensorflow/core:test_main", - "//tensorflow/core/platform:random", - ], -) diff --git a/tensorflow/compiler/xla/python/dlpack.cc b/tensorflow/compiler/xla/python/dlpack.cc index 31f51d70937..d37d480607a 100644 --- a/tensorflow/compiler/xla/python/dlpack.cc +++ b/tensorflow/compiler/xla/python/dlpack.cc @@ -23,7 +23,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/types/span.h" #include "include/dlpack/dlpack.h" // from @dlpack -#include "tensorflow/compiler/xla/python/tracked_device_buffer.h" +#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/stream_executor/cuda/cuda_platform_id.h" diff --git a/tensorflow/compiler/xla/python/dlpack.h b/tensorflow/compiler/xla/python/dlpack.h index 9d8965ac43d..6766bbe93b1 100644 --- a/tensorflow/compiler/xla/python/dlpack.h +++ b/tensorflow/compiler/xla/python/dlpack.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_PYTHON_DLPACK_H_ #include "pybind11/pybind11.h" -#include "tensorflow/compiler/xla/python/local_client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" namespace xla { diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD index b5f1a831d4a..c460cc36f08 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD +++ b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD @@ -19,8 +19,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:executable_build_options", - "//tensorflow/compiler/xla/python:local_client", - "//tensorflow/compiler/xla/python:semaphore", + "//tensorflow/compiler/xla/pjrt:pjrt_client", + "//tensorflow/compiler/xla/pjrt:semaphore", "//tensorflow/compiler/xla/python/tpu_driver", "//tensorflow/compiler/xla/python/tpu_driver:direct_tpu_driver", "//tensorflow/compiler/xla/python/tpu_driver:grpc_tpu_driver", diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc index fe2cddd75ef..e78f04ff980 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc @@ -24,7 +24,7 @@ limitations under the License. #include "absl/time/time.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/python/semaphore.h" +#include "tensorflow/compiler/xla/pjrt/semaphore.h" #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/shape_util.h" diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h index f2c792d2a20..4c45df181db 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h @@ -24,7 +24,7 @@ limitations under the License. 
#include "absl/synchronization/notification.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" -#include "tensorflow/compiler/xla/python/local_client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h" #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py index ef0caff0ae6..6d4482af43f 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py @@ -20,6 +20,9 @@ from __future__ import print_function from absl import logging +# Import xla_client to load shared C++ extensions (just CompileOptions at the +# time of writing). +from tensorflow.compiler.xla.python import xla_client # pylint: disable=unused-import from tensorflow.compiler.xla.python.tpu_driver.client import tpu_client_extension as _tpu_client diff --git a/tensorflow/compiler/xla/python/types.h b/tensorflow/compiler/xla/python/types.h index 4ed4e9cb7f8..673f403d91e 100644 --- a/tensorflow/compiler/xla/python/types.h +++ b/tensorflow/compiler/xla/python/types.h @@ -26,7 +26,7 @@ limitations under the License. #include "pybind11/pybind11.h" #include "pybind11/stl.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/python/local_client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index 206c304abbb..f03595bf677 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -39,14 +39,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/pjrt/cpu_device.h" +#include "tensorflow/compiler/xla/pjrt/distributed/client.h" +#include "tensorflow/compiler/xla/pjrt/distributed/distributed.h" +#include "tensorflow/compiler/xla/pjrt/distributed/service.h" +#include "tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/python/bfloat16.h" -#include "tensorflow/compiler/xla/python/cpu_device.h" -#include "tensorflow/compiler/xla/python/distributed/client.h" -#include "tensorflow/compiler/xla/python/distributed/distributed.h" -#include "tensorflow/compiler/xla/python/distributed/service.h" #include "tensorflow/compiler/xla/python/dlpack.h" -#include "tensorflow/compiler/xla/python/local_client.h" -#include "tensorflow/compiler/xla/python/nvidia_gpu_device.h" #include "tensorflow/compiler/xla/python/python_ref_manager.h" #include "tensorflow/compiler/xla/python/types.h" #include "tensorflow/compiler/xla/service/custom_call_target_registry.h" @@ -980,10 +980,17 @@ PYBIND11_MODULE(xla_extension, m) { py::gil_scoped_release gil_release; TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, device.GetLocalDeviceState()); + Shape shape_with_layout = shape; + ShapeUtil::ForEachMutableSubshape( + &shape_with_layout, [](Shape* subshape, const ShapeIndex&) { + if (!subshape->has_layout()) { + LayoutUtil::SetToDefaultLayout(subshape); + } + }); TF_ASSIGN_OR_RETURN( Literal literal, local_device->client()->TransferFromOutfeedLocal( - shape, local_device->device_ordinal())); + shape_with_layout, local_device->device_ordinal())); literal_shared = std::make_shared(std::move(literal)); } diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 7f09a7e1698..d9cd906939d 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -261,44 +261,6 @@ class ProgramShape(object): """ -class Buffer(object): - """Represents a handle to data owned by XLA. - - The referent is ready for use in executing a local, compiled - Computation. On XLA platforms involving a device (e.g. GPU), this - means the referent is in device memory. - """ - - @staticmethod - def from_pyval(pyval, device=None, backend=None, force_copy=False): - """Copies the `pyval` to a freshly allocated on-device buffer.""" - backend = backend or get_local_backend() - return backend.buffer_from_pyval(pyval, device, force_copy=force_copy) - - # Buffer is not an instantiable type and exists only for its static methods. - # The underlying buffer objects are C++ object with the following - # API: - # def shape(self) -> Shape: - # def device(self) -> int: - # def delete(self): - # def is_deleted(self) -> bool: - # def block_host_until_ready(self): - # """Blocks the calling thread until the buffer is ready on device.""" - # def copy_to_host_async(self): - # """Requests a copy of the buffer to the host. - # - # Does not block waiting for the copy. Values fetched are available via - # `to_py()`; the purpose of `copy_to_host_async` is to prefetch values - # for subsequent `to_py()` calls, especially when requesting many values - # at once. 
- # """ - # def to_py(self): - # """Returns the value of the buffer as a Python tuple tree of ndarrays.""" - # - # TODO(phawkins): remove Buffer and its static methods completely, have - # clients call methods on Backend to create buffers. - - def shape_from_pyval(pyval): """Returns a Shape that describes a tuple-tree of Numpy arrays.""" @@ -311,43 +273,6 @@ def shape_from_pyval(pyval): return convert(pyval) -def transfer_to_infeed(value, device=None): - """Transfers the given value into the XLA infeed queue. - - XLA's infeed queue is a single queue that feeds the "XLA virtual machine" with - a totally ordered stream of values. This is dequeued from XLA computations via - the Infeed() operation. - - Args: - value: the value that the caller would like to enqueue into the XLA infeed - queue - device: the device to infeed the value to. Each device has a distinct infeed - queue. - """ - # TODO(phawkins): support non-default backends. - backend = get_local_backend() - device = device or backend.local_devices()[0] - device.transfer_to_infeed(value) - - -def transfer_from_outfeed(shape, device=None): - """Transfers a literal of the given shape from `device`'s outfeed. - - Args: - shape: The shape of the value to transfer from outfeed. - device: The device from which to transfer the outfeed value. Each device has - a distinct outfeed queue.. - - Returns: - The literal value that is produced from the outfeed queue. - """ - # TODO(phawkins): support non-default backends. - backend = get_local_backend() - device = device or backend.local_devices()[0] - return device.transfer_from_outfeed( - shape.with_major_to_minor_layout_if_absent()) - - DeviceAssignment = _xla.DeviceAssignment DeviceAssignment.__doc__ = """ A DeviceAssignment is a C++ object with the following signature. diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 62b3fae018a..fbdd9921a40 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -2029,8 +2029,11 @@ def TestFactory(xla_backend, cloud_tpu=False): return tests -def InstantiateTests(globals_dict, backend, test_prefix="", **kw): - for klass in TestFactory(backend, **kw): +def InstantiateTests(globals_dict, backend_fn, test_prefix="", **kw): + # Avoid creating a new backend per test (this causes GPU OOM, and is probably + # inefficient). + backend_fn = functools.lru_cache(maxsize=None)(backend_fn) + for klass in TestFactory(backend_fn, **kw): test = type(test_prefix + klass.__name__, (klass,), {}) # Clean up the qualified names of the tests to not include the test factory. 
test.__qualname__ = test.__name__ diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index aef215e23e8..126b62a8eb2 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -460,6 +460,37 @@ cc_library( ], ) +cc_library( + name = "hlo_sharding_util", + srcs = [ + "hlo_sharding_util.cc", + ], + hdrs = [ + "hlo_sharding_util.h", + ], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:array", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/types:optional", + ], +) + +tf_cc_test( + name = "hlo_sharding_util_test", + srcs = [ + "hlo_sharding_util_test.cc", + ], + deps = [ + ":hlo_sharding_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + tf_cc_test( name = "dynamic_parameter_binding_test", srcs = ["dynamic_parameter_binding_test.cc"], @@ -2122,6 +2153,51 @@ tf_cc_test( ], ) +cc_library( + name = "conditional_code_motion", + srcs = ["conditional_code_motion.cc"], + hdrs = ["conditional_code_motion.h"], + deps = [ + ":call_graph", + ":call_inliner", + ":hlo", + ":hlo_casting_utils", + ":hlo_dce", + ":hlo_pass", + ":hlo_pass_pipeline", + ":tuple_simplifier", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "conditional_code_motion_test", + srcs = ["conditional_code_motion_test.cc"], + deps = [ + ":conditional_code_motion", + ":hlo", + ":hlo_matchers", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "convolution_group_converter", srcs = ["convolution_group_converter.cc"], @@ -2352,6 +2428,42 @@ tf_cc_test( ], ) +cc_library( + name = "all_gather_decomposer", + srcs = ["all_gather_decomposer.cc"], + hdrs = ["all_gather_decomposer.h"], + deps = [ + ":hlo", + ":hlo_casting_utils", + ":hlo_pass", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "all_gather_decomposer_test", + srcs = ["all_gather_decomposer_test.cc"], + deps = [ + ":all_gather_decomposer", + ":hlo", + ":hlo_matchers", + ":hlo_parser", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "tuple_simplifier", srcs = ["tuple_simplifier.cc"], @@ -3189,6 +3301,29 @@ tf_cc_test( ], ) 
+cc_library( + name = "memory_space_propagation", + srcs = ["memory_space_propagation.cc"], + hdrs = ["memory_space_propagation.h"], + deps = [ + ":hlo", + ":hlo_dataflow_analysis", + ":hlo_pass", + ], +) + +tf_cc_test( + name = "memory_space_propagation_test", + srcs = ["memory_space_propagation_test.cc"], + deps = [ + ":hlo_parser", + ":memory_space_propagation", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_dce", srcs = ["hlo_dce.cc"], @@ -3742,6 +3877,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:core", "@llvm-project//llvm:transform_utils", ], diff --git a/tensorflow/compiler/xla/service/all_gather_decomposer.cc b/tensorflow/compiler/xla/service/all_gather_decomposer.cc new file mode 100644 index 00000000000..ad63218eca8 --- /dev/null +++ b/tensorflow/compiler/xla/service/all_gather_decomposer.cc @@ -0,0 +1,154 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/all_gather_decomposer.h" + +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +// Creates a computation of x + y. 
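+// For PRED the reduction uses logical OR instead of kAdd, so booleans combine
+// correctly with the zero (false) filler used below. Roughly, for F32 the
+// embedded computation corresponds to this HLO sketch:
+//   add {
+//     x = f32[] parameter(0)
+//     y = f32[] parameter(1)
+//     ROOT add = f32[] add(x, y)
+//   }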
+HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module) {
+  HloComputation::Builder sum_b("add");
+  auto x = sum_b.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/0, ShapeUtil::MakeShape(type, {}), "x"));
+  auto y = sum_b.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/1, ShapeUtil::MakeShape(type, {}), "y"));
+  if (type == PRED) {
+    sum_b.AddInstruction(HloInstruction::CreateBinary(
+        ShapeUtil::MakeShape(type, {}), HloOpcode::kOr, x, y));
+  } else {
+    sum_b.AddInstruction(HloInstruction::CreateBinary(
+        ShapeUtil::MakeShape(type, {}), HloOpcode::kAdd, x, y));
+  }
+  HloComputation* reduction = module->AddEmbeddedComputation(sum_b.Build());
+  return reduction;
+}
+
+Status DecomposeAllGather(HloAllGatherInstruction* ag, int64 partition_count,
+                          HloComputation* comp) {
+  auto zero = comp->AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::Zero(ag->shape().element_type())));
+  zero = comp->AddInstruction(
+      HloInstruction::CreateBroadcast(ag->shape(), zero, {}));
+  auto zero_index = comp->AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
+  std::vector<HloInstruction*> start_indices(ag->shape().rank(), zero_index);
+  auto shard_id_from_subgroup = [&](HloInstruction* replica_or_global_id) {
+    if (ag->replica_groups().empty()) {
+      return replica_or_global_id;
+    }
+    if (ag->replica_groups().size() == 1) {
+      // Whether the group is {0, 1, ..., N - 1}.
+      bool trivial_group = true;
+      for (int64 i = 0; i < ag->replica_groups()[0].replica_ids_size(); ++i) {
+        if (ag->replica_groups()[0].replica_ids(i) != i) {
+          trivial_group = false;
+          break;
+        }
+      }
+      if (trivial_group) {
+        CHECK_EQ(partition_count, ag->replica_groups()[0].replica_ids_size());
+        return replica_or_global_id;
+      }
+    }
+    // Create a table of shard IDs for each replica_or_global_id, then slice it
+    // using replica_or_global_id.
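+    // For example, with replica_groups={{2,1,0,3},{4,6,7,5}} (as in the
+    // subgroup tests) the table is {2,1,0,3,0,3,1,2}; id 6 maps to shard 1
+    // because it is the second element of its group.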
+ std::vector shard_ids(ag->replica_groups().size() * + ag->replica_groups()[0].replica_ids_size()); + for (const auto& group : ag->replica_groups()) { + for (int64 i = 0; i < group.replica_ids_size(); ++i) { + shard_ids[group.replica_ids(i)] = i; + } + } + auto id_table = comp->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1(shard_ids))); + auto shard_id = comp->AddInstruction(HloInstruction::CreateDynamicSlice( + ShapeUtil::MakeShape(S32, {1}), id_table, {replica_or_global_id}, {1})); + shard_id = comp->AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), shard_id)); + return shard_id; + }; + HloInstruction* shard_id; + if (ag->channel_id().has_value()) { + if (ag->use_global_device_ids()) { + auto pid = comp->AddInstruction(HloInstruction::CreatePartitionId()); + auto rid = comp->AddInstruction(HloInstruction::CreateReplicaId()); + auto pcount = comp->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(partition_count))); + auto global_id = comp->AddInstruction(HloInstruction::CreateBinary( + pid->shape(), HloOpcode::kAdd, pid, + comp->AddInstruction(HloInstruction::CreateBinary( + pid->shape(), HloOpcode::kMultiply, rid, pcount)))); + shard_id = shard_id_from_subgroup(global_id); + } else { + TF_RET_CHECK(!ag->replica_groups().empty()); + TF_RET_CHECK(ag->replica_groups()[0].replica_ids_size() == 1); + shard_id = comp->AddInstruction(HloInstruction::CreatePartitionId()); + } + } else { + shard_id = shard_id_from_subgroup( + comp->AddInstruction(HloInstruction::CreateReplicaId())); + } + start_indices[ag->all_gather_dimension()] = + comp->AddInstruction(HloInstruction::CreateBinary( + shard_id->shape(), HloOpcode::kMultiply, shard_id, + comp->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(ag->operand(0)->shape().dimensions( + ag->all_gather_dimension())))))); + auto dus = comp->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + zero->shape(), zero, ag->mutable_operand(0), start_indices)); + auto ar = comp->AddInstruction(HloInstruction::CreateAllReduce( + dus->shape(), {dus}, + MakeBinaryAdd(dus->shape().element_type(), comp->parent()), + ag->replica_groups(), + /*constrain_layout=*/ag->constrain_layout(), ag->channel_id(), + ag->use_global_device_ids())); + TF_RETURN_IF_ERROR(ag->ReplaceAllUsesWith(ar)); + TF_RETURN_IF_ERROR(comp->RemoveInstructionAndUnusedOperands(ag)); + return Status::OK(); +} + +StatusOr AllGatherDecomposer::Run(HloModule* module) { + bool changed = false; + for (auto comp : module->MakeNonfusionComputations()) { + for (auto hlo : comp->MakeInstructionPostOrder()) { + if (hlo->opcode() != HloOpcode::kAllGather) { + continue; + } + auto ag = Cast(hlo); + if (should_decompose_(*ag)) { + TF_RETURN_IF_ERROR(DecomposeAllGather(ag, partition_count_, comp)); + changed = true; + } + } + } + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/all_gather_decomposer.h b/tensorflow/compiler/xla/service/all_gather_decomposer.h new file mode 100644 index 00000000000..d1983e37383 --- /dev/null +++ b/tensorflow/compiler/xla/service/all_gather_decomposer.h @@ -0,0 +1,51 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ALL_GATHER_DECOMPOSER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_ALL_GATHER_DECOMPOSER_H_ + +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// AllGatherDecomposer is a pass which converts unsupported all-gathers into +// dynamic-update-slices and all-reduces. +class AllGatherDecomposer : public HloModulePass { + public: + AllGatherDecomposer( + std::function should_decompose, + int64 partition_count) + : should_decompose_(std::move(should_decompose)), + partition_count_(partition_count) {} + explicit AllGatherDecomposer(int64 partition_count) + : should_decompose_( + [](const HloAllGatherInstruction& ag) { return true; }), + partition_count_(partition_count) {} + absl::string_view name() const override { return "all_gather_decomposer"; } + + // Run AllGatherDecomposer pass on computations in 'module'. + // Returns whether the 'module' was changed. + StatusOr Run(HloModule* module) override; + + private: + std::function should_decompose_; + int64 partition_count_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_ALL_GATHER_DECOMPOSER_H_ diff --git a/tensorflow/compiler/xla/service/all_gather_decomposer_test.cc b/tensorflow/compiler/xla/service/all_gather_decomposer_test.cc new file mode 100644 index 00000000000..ebcd66ffa07 --- /dev/null +++ b/tensorflow/compiler/xla/service/all_gather_decomposer_test.cc @@ -0,0 +1,161 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/all_gather_decomposer.h" + +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +using ::testing::AllOf; +namespace op = xla::testing::opcode_matchers; +using AllGatherDecomposerTest = HloTestBase; + +TEST_F(AllGatherDecomposerTest, CrossReplicaAllGather) { + const string module_str = R"( +HloModule module + +ENTRY entry { + param0 = f32[10,20] parameter(0) + ROOT ag = f32[10,80] all-gather(param0), replica_groups={}, dimensions={1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((module_str))); + AllGatherDecomposer decomposer(/*partition_count=*/4); + TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(op::Constant()), op::Parameter(0), op::Constant(), + op::Multiply(op::ReplicaId(), op::Constant())))); +} + +TEST_F(AllGatherDecomposerTest, CrossPartitionAllGather) { + const string module_str = R"( +HloModule module + +ENTRY entry { + param0 = f32[10,20] parameter(0) + ROOT ag = f32[10,80] all-gather(param0), replica_groups={{0}}, channel_id=1, + dimensions={1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((module_str))); + AllGatherDecomposer decomposer(/*partition_count=*/4); + TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(op::Constant()), op::Parameter(0), op::Constant(), + op::Multiply(op::PartitionId(), op::Constant())))); +} + +TEST_F(AllGatherDecomposerTest, CrossReplicaAllGatherWithTrivialGroup) { + const string module_str = R"( +HloModule module + +ENTRY entry { + param0 = f32[10,20] parameter(0) + ROOT ag = f32[10,80] all-gather(param0), replica_groups={{0,1,2,3}}, + dimensions={1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((module_str))); + AllGatherDecomposer decomposer(/*partition_count=*/4); + TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(op::Constant()), op::Parameter(0), op::Constant(), + op::Multiply(op::ReplicaId(), op::Constant())))); +} + +TEST_F(AllGatherDecomposerTest, CrossReplicaAllGatherWithSubgroups) { + const string module_str = R"( +HloModule module + +ENTRY entry { + param0 = f32[10,20] parameter(0) + ROOT ag = f32[10,80] all-gather(param0), + replica_groups={{2,1,0,3}, {4,6,7,5}}, dimensions={1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + 
ParseAndReturnUnverifiedModule((module_str))); + AllGatherDecomposer decomposer(/*partition_count=*/4); + TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); + EXPECT_TRUE(changed); + auto id = + AllOf(op::Shape("s32[]"), + op::Reshape(op::DynamicSlice(op::Constant(), op::ReplicaId()))); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(op::Constant()), op::Parameter(0), + op::Constant(), op::Multiply(id, op::Constant())))); +} + +TEST_F(AllGatherDecomposerTest, CrossReplicaAllGatherWithSubgroupsGlobalIds) { + const string module_str = R"( +HloModule module + +ENTRY entry { + param0 = f32[10,20] parameter(0) + ROOT ag = f32[10,80] all-gather(param0), + replica_groups={{2,1,0,3}, {4,6,7,5}}, dimensions={1}, channel_id=1, + use_global_device_ids=true +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((module_str))); + AllGatherDecomposer decomposer(/*partition_count=*/4); + TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); + EXPECT_TRUE(changed); + LOG(ERROR) << module->ToString(); + auto global_id = + op::Add(op::PartitionId(), op::Multiply(op::ReplicaId(), op::Constant())); + auto id = AllOf(op::Shape("s32[]"), + op::Reshape(op::DynamicSlice(op::Constant(), global_id))); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(op::Constant()), op::Parameter(0), + op::Constant(), op::Multiply(id, op::Constant())))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_support.cc b/tensorflow/compiler/xla/service/bfloat16_support.cc index abb695fa486..30d764225c2 100644 --- a/tensorflow/compiler/xla/service/bfloat16_support.cc +++ b/tensorflow/compiler/xla/service/bfloat16_support.cc @@ -79,6 +79,7 @@ bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision( const HloInstruction& hlo, int64 operand_index) { switch (hlo.opcode()) { case HloOpcode::kAbs: + case HloOpcode::kAllGather: case HloOpcode::kAllToAll: case HloOpcode::kBroadcast: case HloOpcode::kClamp: diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index 8c76e912011..ce9c8a4ea62 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -91,6 +91,7 @@ CompileOnlyService::CompileAheadOfTime( TF_RETURN_IF_ERROR(options.static_device_assignment().Serialize( execution_options.mutable_device_assignment())); } + execution_options.set_use_spmd_partitioning(options.use_spmd_partitioning()); for (const AotXlaComputationInstance& instance : computations) { TF_RET_CHECK(instance.computation.has_host_program_shape()); *execution_options.mutable_shape_with_output_layout() = diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index cf646159a38..57b24e372e6 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -76,6 +76,7 @@ class AotCompilationOptions { virtual int64 replica_count() const { return 0; } virtual int64 num_cores() const { return 0; } + virtual bool use_spmd_partitioning() const { return false; } // Optional allocator that may be used for allocating temp space on the device // during compilation. 
diff --git a/tensorflow/compiler/xla/service/conditional_code_motion.cc b/tensorflow/compiler/xla/service/conditional_code_motion.cc
new file mode 100644
index 00000000000..eecdcc851e9
--- /dev/null
+++ b/tensorflow/compiler/xla/service/conditional_code_motion.cc
@@ -0,0 +1,483 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/conditional_code_motion.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_cat.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/call_graph.h"
+#include "tensorflow/compiler/xla/service/call_inliner.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_dce.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
+#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/errors.h"
+
+namespace xla {
+
+namespace {
+
+struct ConditionalBoundary {
+  ConditionalBoundary(HloInstruction* op, int64 op_index, HloInstruction* usr)
+      : operand(op), operand_index(op_index), user(usr) {}
+  // `operand` is one of `user`'s operands.
+
+  // Instruction that remains in the conditional but one of its users
+  // is moved out of the conditional.
+  HloInstruction* operand;
+  // operand_index for `operand` in the `user`.
+  int64 operand_index;
+  // Instruction that is moved out of the conditional.
+  HloInstruction* user;
+};
+
+// Visits the instructions of a branch from the root down to its operands,
+// following BFS. An instruction is visited only after all of its users have
+// been visited. Parameters are not visited.
+class BranchVisitor {
+ public:
+  explicit BranchVisitor(const HloComputation* branch_computation) {
+    HloInstruction* root_inst = branch_computation->root_instruction();
+    worklist_.push_back(root_inst);
+    visited_.insert(root_inst);
+    for (auto parameter_inst : branch_computation->parameter_instructions()) {
+      parameter_instructions_.insert(parameter_inst);
+    }
+  }
+  // Get the next instruction to visit.
+  HloInstruction* GetNextInstruction() {
+    if (!worklist_.empty()) {
+      HloInstruction* inst = worklist_.front();
+      worklist_.pop_front();
+      return inst;
+    }
+    return nullptr;
+  }
+
+  // Add operands of one instruction to the worklist for further visits.
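+  // An operand is enqueued only once all of its users have been visited;
+  // operands that are parameter instructions are recorded as boundaries
+  // instead of being visited.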
+  void AddInstructionOperands(HloInstruction* inst) {
+    int64 operand_count = inst->operand_count();
+    for (int i = 0; i < operand_count; i++) {
+      HloInstruction* operand = inst->mutable_operand(i);
+      if (ContainsKey(visited_, operand)) {
+        continue;
+      }
+      bool all_user_visited = std::all_of(
+          operand->users().begin(), operand->users().end(),
+          [&](HloInstruction* user) { return ContainsKey(visited_, user); });
+
+      if (!all_user_visited) {
+        continue;
+      }
+      // Do not visit parameter_instructions.
+      if (ContainsKey(parameter_instructions_, operand)) {
+        // Add the operand and this instruction to the boundaries.
+        boundaries_.emplace_back(operand, i, inst);
+        continue;
+      }
+
+      worklist_.push_back(operand);
+      visited_.insert(operand);
+    }
+  }
+
+  // Add instruction and its users to conditional boundaries.
+  void AddInstructionToBoundary(HloInstruction* inst) {
+    for (auto user : inst->users()) {
+      boundaries_.emplace_back(inst, user->operand_index(inst), user);
+    }
+  }
+
+  // Add the instruction to the set and vector of instructions to hoist.
+  void AddInstructionToHoist(HloInstruction* inst) {
+    instructions_to_hoist_set_.insert(inst);
+    instructions_to_hoist_.emplace_back(inst);
+  }
+
+  // Whether the visitor has a next instruction to visit.
+  bool HasNextInstruction() const { return !worklist_.empty(); }
+
+  // Number of instructions to hoist.
+  int64 HoistInstructionSize() { return instructions_to_hoist_.size(); }
+
+  // Get boundaries of this branch.
+  const std::vector<ConditionalBoundary>& boundaries() const {
+    return boundaries_;
+  }
+
+  // Get instructions to hoist in this branch.
+  const std::vector<HloInstruction*>& instructions_to_hoist() const {
+    return instructions_to_hoist_;
+  }
+
+  // Get hoist instruction set in this branch.
+  const std::unordered_set<HloInstruction*>& instructions_to_hoist_set() const {
+    return instructions_to_hoist_set_;
+  }
+
+ private:
+  // worklist is the deque that contains instructions to be visited.
+  std::deque<HloInstruction*> worklist_;
+
+  // Instructions that have been visited.
+  std::unordered_set<HloInstruction*> visited_;
+
+  // Parameter instructions of the branch.
+  std::unordered_set<HloInstruction*> parameter_instructions_;
+
+  // Boundaries contains the instructions whose operands remain within the
+  // conditional but which can themselves be hoisted out of the conditional.
+  std::vector<ConditionalBoundary> boundaries_;
+
+  // Instructions to hoist.
+  std::unordered_set<HloInstruction*> instructions_to_hoist_set_;
+
+  // Instructions to hoist; the order within this vector is BFS and
+  // an instruction's order will always be after its users.
+  std::vector<HloInstruction*> instructions_to_hoist_;
+};
+
+// Returns true if `instruction` is worth hoisting out.
+bool WorthHoisting(HloInstruction* instruction) {
+  for (const auto* operand : instruction->operands()) {
+    // Only move out instructions whose operands are not shared with other
+    // users, to avoid copying the operand.
+    if (operand->user_count() > 1) {
+      return false;
+    }
+  }
+  switch (instruction->opcode()) {
+    case HloOpcode::kConvert:
+      // If Convert is after AllReduce, it is worth moving the AllReduce out
+      // of the conditional for AR/CRS combining. If Convert follows other ops
+      // such as Dot or Convolution, it is better to keep the convert within
+      // the conditional so that it can be fused with the Dot or Convolution.
+      //
+      // TODO(b/154283721): figure out the scenario when convert can be fused
+      // with AllReduce out of conditional.
+ if (instruction->operand(0)->opcode() == HloOpcode::kAllReduce) { + return true; + } + return false; + case HloOpcode::kAllReduce: + case HloOpcode::kAdd: + case HloOpcode::kConstant: + case HloOpcode::kSubtract: + case HloOpcode::kMultiply: + case HloOpcode::kDivide: + case HloOpcode::kTuple: + case HloOpcode::kGetTupleElement: + return true; + default: + return false; + } +} + +// Compare if the instructions to be visited at each branches are identical. +bool InstructionWithinBranchIdentical( + const std::vector& instructions, bool is_layout_senstive) { + // Identical includes the shape of each operands are equal. + auto eq_operand = [&](const HloInstruction* a, const HloInstruction* b) { + bool eq_operands = is_layout_senstive + ? ShapeUtil::Equal(a->shape(), b->shape()) + : ShapeUtil::Compatible(a->shape(), b->shape()); + return eq_operands; + }; + + auto eq_computations = [](const HloComputation* a, const HloComputation* b) { + return *a == *b; + }; + + if (instructions[0] == nullptr) { + return false; + } + + if (instructions[0]->IsCrossModuleAllReduce()) { + return std::all_of( + instructions.begin(), instructions.end(), + [&](HloInstruction* instruction) { + if (!instruction->IsCrossModuleAllReduce()) { + return false; + } + auto old_channel_id = instruction->channel_id(); + instruction->set_channel_id(instructions[0]->channel_id()); + bool eq_instructions = instructions[0]->Identical( + *instruction, eq_operand, eq_computations, is_layout_senstive); + instruction->set_channel_id(old_channel_id); + return eq_instructions; + }); + } + + return std::all_of(instructions.begin(), instructions.end(), + [&](HloInstruction* instruction) { + return instructions[0]->Identical( + *instruction, eq_operand, eq_computations, + is_layout_senstive); + }); +} + +// Returns if all the visitors/branches has next instruction to visit. +bool HasNextInstruction(const std::vector& visitors) { + bool has_next = true; + for (const auto& visitor : visitors) { + has_next &= visitor.HasNextInstruction(); + } + return has_next; +} + +// Create tuple element as the new root of the branch. The tuple will contain +// the operands that can't move out of conditional but its user will be moved +// out of conditional. +HloInstruction* CreateNewRoot( + const std::vector& boundaries, + const std::unordered_set& instructions_to_hoist_set, + HloComputation* computation) { + std::vector elements; + elements.reserve(boundaries.size()); + for (auto boundary : boundaries) { + if (ContainsKey(instructions_to_hoist_set, boundary.user)) { + elements.push_back(boundary.operand); + } + } + return computation->AddInstruction(HloInstruction::CreateTuple(elements)); +} + +// Copy identical instructions within conditional outside of conditional. +void CopyIdenticalInstructionsOutOfConditional( + const std::vector& instructions_to_hoist, + HloComputation* conditional_parent, + absl::flat_hash_map* + hoisted_instructions) { + int64 instructions_size = instructions_to_hoist.size(); + // Visit the operands before its users and copy it, so that the copied + // user will point to the correct operand. + for (int64 i = instructions_size - 1; i >= 0; i--) { + HloInstruction* old_instruction = instructions_to_hoist[i]; + auto get_new_operand = [&](HloInstruction* old_operand) { + // If the operand can't be found in `instructions_to_hoist`, this + // operand will be in the `boundaries`, GetTupleElement instructions + // will be added later to replace this operand. 
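+      // (Such an operand stays inside the branch; the hoisted user reads it
+      // through a get-tuple-element of the conditional once
+      // CreateGetTupleElementAfterConditional below has run.)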
+ if (!ContainsKey(*hoisted_instructions, old_operand)) { + return old_operand; + } + return FindOrDie(*hoisted_instructions, old_operand); + }; + + absl::InlinedVector new_operands; + absl::c_transform(old_instruction->operands(), + std::back_inserter(new_operands), get_new_operand); + + HloInstruction* new_instruction = conditional_parent->AddInstruction( + old_instruction->CloneWithNewOperands(old_instruction->shape(), + new_operands)); + // Maps the instruction outside of conditional to the instruction + // inside of the conditional. + InsertOrDie(hoisted_instructions, old_instruction, new_instruction); + } +} + +// If there are instructions to hoist, the root of the conditional must be +// moved out. Change the users of the conditional to the hoisted instruction +// of the new root. +Status ChangeConditionalUsers( + HloInstruction* conditional, HloInstruction* old_root, + const absl::flat_hash_map& + hoisted_instructions) { + HloInstruction* new_root = FindOrDie(hoisted_instructions, old_root); + TF_RETURN_IF_ERROR(conditional->ReplaceAllUsesWith(new_root)); + return Status::OK(); +} + +// Insert GetTupleElement before the instructions whose operands might still +// be within the conditional. +Status CreateGetTupleElementAfterConditional( + const std::vector& boundaries, + const std::unordered_set& instructions_to_hoist_set, + const absl::flat_hash_map& + hoisted_instructions, + HloInstruction* conditional, HloComputation* computation) { + int boundary_instruction_size = boundaries.size(); + + // Inserts GetTupleElement before the boundary instructions. + for (int i = 0; i < boundary_instruction_size; i++) { + HloInstruction* gte = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + boundaries[i].operand->shape(), conditional, i)); + + HloInstruction* new_instruction = + FindOrDie(hoisted_instructions, boundaries[i].user); + TF_RETURN_IF_ERROR( + new_instruction->ReplaceOperandWith(boundaries[i].operand_index, gte)); + } + return Status::OK(); +} + +// Remove instructions to be hoisted out of the branch computation. +Status RemoveInstructionFromComputation( + const std::vector& instructions_to_hoist, + HloComputation* branch) { + // Will visit the instructions after its users. + for (auto* instruction : instructions_to_hoist) { + TF_RETURN_IF_ERROR(branch->RemoveInstruction(instruction)); + } + return Status::OK(); +} + +// Hoist identical ops out of the conditional. The definition of identical +// are the shape of the operands are identical and their properties are +// identical. Will start from the root instruction of each branch and get +// the identical ops to hoist. +StatusOr MergeIdenticalElements(HloInstruction* conditional, + bool is_layout_sensitive) { + int branch_count = conditional->branch_count(); + if (branch_count <= 0) { + return false; + } + + std::vector visitors; + visitors.reserve(branch_count); + // Visit instructions from the root instruction to the operands using BFS. + for (int i = 0; i < branch_count; i++) { + visitors.emplace_back(BranchVisitor(conditional->branch_computation(i))); + } + + // The instructions to be visited within each branch. + std::vector front_instructions(branch_count); + + while (HasNextInstruction(visitors)) { + for (int i = 0; i < branch_count; i++) { + front_instructions[i] = visitors[i].GetNextInstruction(); + } + // If two instructions has the same shape, opcode and its operands has the + // same shape, then this instruction can be moved out of conditional. 
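+    // For example, if every branch ends in add(get-tuple-element, constant)
+    // with matching shapes and constant values, those adds are recorded for
+    // hoisting and their operands are visited next; mismatching instructions
+    // instead become boundaries and stay inside their branches.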
+ if (WorthHoisting(front_instructions[0]) && + InstructionWithinBranchIdentical(front_instructions, + is_layout_sensitive)) { + for (int i = 0; i < branch_count; i++) { + visitors[i].AddInstructionOperands(front_instructions[i]); + visitors[i].AddInstructionToHoist(front_instructions[i]); + } + } else { + for (int i = 0; i < branch_count; i++) { + // If the ops are not identical, these ops and its users will + // be in the boundaries` of the conditional. These ops will be stayed + // within the conditional, but one its only user will be moved out + // of conditional. + visitors[i].AddInstructionToBoundary(front_instructions[i]); + } + } + } + + if (visitors[0].HoistInstructionSize() <= 1) { + return false; + } + + HloInstruction* old_root = + conditional->branch_computation(0)->root_instruction(); + HloComputation* conditional_parent = conditional->parent(); + // Maps instructions in the conditional body to instructions hoisted outside + // the conditional that compute the same value. + absl::flat_hash_map hoisted_instructions; + // Copy identical instructions out of the conditional. + CopyIdenticalInstructionsOutOfConditional(visitors[0].instructions_to_hoist(), + conditional_parent, + &hoisted_instructions); + // If there are instructions to hoist, the root of the conditional must be + // moved out. Change the users of the conditional to the hoisted instruction + // of the new root. + TF_RETURN_IF_ERROR( + ChangeConditionalUsers(conditional, old_root, hoisted_instructions)); + + // Create tuple element within each branch and set it as root. + for (int i = 0; i < branch_count; i++) { + HloInstruction* tuple = CreateNewRoot( + visitors[i].boundaries(), visitors[i].instructions_to_hoist_set(), + conditional->branch_computation(i)); + conditional->branch_computation(i)->set_root_instruction(tuple, true); + } + // Changes conditional instruction shape to the shape of the new root. + *conditional->mutable_shape() = + conditional->branch_computation(0)->root_instruction()->shape(); + + // Insert GetTupleElement before the instructions whose operands might still + // be within the conditional. + TF_RETURN_IF_ERROR(CreateGetTupleElementAfterConditional( + visitors[0].boundaries(), visitors[0].instructions_to_hoist_set(), + hoisted_instructions, conditional, conditional_parent)); + + // Remove hoist instructions from the branches. + for (int i = 0; i < branch_count; i++) { + TF_RETURN_IF_ERROR( + RemoveInstructionFromComputation(visitors[i].instructions_to_hoist(), + conditional->branch_computation(i))); + } + + return true; +} + +} // namespace + +StatusOr ConditionalCodeMotion::Run(HloModule* module) { + bool changed = false; + + // Gather all the conditional ops in our module. We do this ahead of time so + // we don't have to worry about mutating the lists of computations or + // instructions as we iterate. 
+ std::vector conditional_ops; + for (auto* comp : module->MakeComputationPostOrder()) { + for (auto* instr : comp->MakeInstructionPostOrder()) { + if (instr->opcode() == HloOpcode::kConditional) { + conditional_ops.push_back(instr); + } + } + } + + for (HloInstruction* conditional_op : conditional_ops) { + TF_ASSIGN_OR_RETURN(bool result, MergeIdenticalElements( + conditional_op, is_layout_sensitive_)); + changed |= result; + } + + if (changed) { + HloPassPipeline subpipeline("after_conditional_code_motion"); + subpipeline.AddPass(); + subpipeline.AddPass(); + TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module)); + changed |= cleanup_changed; + } + + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/conditional_code_motion.h b/tensorflow/compiler/xla/service/conditional_code_motion.h new file mode 100644 index 00000000000..1197a8b3620 --- /dev/null +++ b/tensorflow/compiler/xla/service/conditional_code_motion.h @@ -0,0 +1,49 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_CODE_MOTION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_CODE_MOTION_H_ + +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// HLO pass that moves identical ops out of conditional. +// - The definition of identical are the shape of the operands are identical +// and their properties are identical. +// - Currently, only some types of instructions is supported. +// TODO(b/154283721): relax non-sharable operand constraint and avoid copies in +// the new root. +// - Only the identical ops that won't share operands with other ops will +// be moved out of conditional. +class ConditionalCodeMotion : public HloModulePass { + public: + // If is_layout_sensitive is true, then the hoist process preserves layout + // during identical comparison. Otherwise, layout is ignored. + explicit ConditionalCodeMotion(bool is_layout_sensitive = true) + : is_layout_sensitive_(is_layout_sensitive) {} + absl::string_view name() const override { return "conditional-code-motion"; } + StatusOr Run(HloModule* module) override; + + private: + const bool is_layout_sensitive_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_CODE_MOTION_H_ diff --git a/tensorflow/compiler/xla/service/conditional_code_motion_test.cc b/tensorflow/compiler/xla/service/conditional_code_motion_test.cc new file mode 100644 index 00000000000..4a52303a42a --- /dev/null +++ b/tensorflow/compiler/xla/service/conditional_code_motion_test.cc @@ -0,0 +1,413 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/conditional_code_motion.h" + +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +using ConditionalCodeMotionTest = HloTestBase; +namespace op = xla::testing::opcode_matchers; + +TEST_F(ConditionalCodeMotionTest, DoNotMoveConvertOut) { + absl::string_view hlo_string = + R"( +HloModule RemoveDotOpOut + +on_true { + %arg_tuple.1 = (f32[93184,4]{1,0}) parameter(0) + %get-tuple-element.1 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.1), index=0 + %reshape.8493 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.1) + %convert.2894 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.8493) + ROOT %tuple.1 = ( bf16[2,512,364]{2,1,0}) tuple(%convert.2894) +} + +on_false { + %arg_tuple.2 = (f32[93184,4]{1,0}) parameter(0) + %get-tuple-element.3 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.2), index=0 + %reshape.9717 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.3) + %convert.3604 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.9717), metadata={op_type="Cast" op_name="gradients/Cast_125_grad/Cast"} + ROOT %tuple.2 = (bf16[2,512,364]{2,1,0}) tuple(%convert.3604) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + arg_tuple.11 = (f32[93184,4]{1,0}) parameter(1) + arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2) + conditional = (bf16[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false + get-first-index = bf16[2,512,364]{2,1,0} get-tuple-element(conditional), index=0 + ROOT result = (bf16[2,512,364]{2,1,0}) tuple(get-first-index) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass; + ASSERT_FALSE(pass.Run(&*module).ValueOrDie()); +} + +TEST_F(ConditionalCodeMotionTest, UserShareOperandCannotBeMoved) { + absl::string_view hlo_string = + R"( +HloModule RemoveIdenticalInstruction + +on_true { + arg_tuple.1 = (f32[]) parameter(0) + get-tuple-element.1 = f32[] get-tuple-element(arg_tuple.1), index=0 + constant.1 = f32[] constant(1) + constant.2 = f32[] constant(2) + constant.3 = f32[] constant(3) + constant.4 = f32[] constant(4) + constant.5 = f32[] constant(5) + add.1 = f32[] 
add(get-tuple-element.1, constant.1) + add.2 = f32[] add(add.1, constant.2) + add.3 = f32[] add(add.1, constant.3) + add.4 = f32[] add(add.3, constant.5) + multiply.1 = f32[] multiply(add.2, constant.4) + ROOT tuple.6 = (f32[], f32[]) tuple(multiply.1, add.4) +} + +on_false { + arg_tuple.2 = (f32[]) parameter(0) + get-tuple-element.2 = f32[] get-tuple-element(arg_tuple.2), index=0 + constant.6 = f32[] constant(1) + constant.7 = f32[] constant(2) + constant.8 = f32[] constant(3) + constant.9 = f32[] constant(4) + constant.10 = f32[] constant(5) + add.4 = f32[] add(get-tuple-element.2, constant.6) + sub.1 = f32[] subtract(add.4, constant.7) + add.5 = f32[] add(add.4, constant.8) + add.6 = f32[] add(add.5, constant.10) + multiply.2 = f32[] multiply(sub.1, constant.9) + ROOT tuple.6 = (f32[], f32[]) tuple(multiply.2, add.6) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + tuple.1 = (f32[]) parameter(1) + tuple.2 = (f32[]) parameter(2) + conditional = (f32[], f32[]) + conditional(pred.1, tuple.1, tuple.2), true_computation=on_true, + false_computation=on_false + get-first-index = f32[] get-tuple-element(conditional), index=0 + get-second-index = f32[] get-tuple-element(conditional), index=1 + ROOT result = (f32[], f32[]) tuple(get-first-index, get-second-index) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass; + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + + const HloInstruction* conditional = + FindInstruction(module.get(), "conditional"); + const HloComputation* on_true = conditional->branch_computation(0); + ASSERT_EQ(on_true->instruction_count(), 9); + const HloComputation* on_false = conditional->branch_computation(1); + ASSERT_EQ(on_false->instruction_count(), 9); + + // Check only one add and multiply is moved out. 
+ HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::Tuple( + op::Multiply(op::GetTupleElement(op::Conditional()), op::Constant()), + op::Add(op::GetTupleElement(op::Conditional()), op::Constant())))); +} + +TEST_F(ConditionalCodeMotionTest, ConditionalRootElementChanged) { + absl::string_view hlo_string = + R"( +HloModule RemoveIdenticalInstruction + +on_true { + arg_tuple.1 = (f32[]) parameter(0) + get-tuple-element.1 = f32[] get-tuple-element(arg_tuple.1), index=0 + constant.1 = f32[] constant(1) + constant.2 = f32[] constant(2) + add.1 = f32[] add(get-tuple-element.1, constant.1) + add.2 = f32[] add(get-tuple-element.1, constant.2) + add.3 = f32[] add(add.1, add.2) + ROOT tuple.3 = (f32[]) tuple(add.3) +} + +on_false { + arg_tuple.2 = (f32[]) parameter(0) + get-tuple-element.2 = f32[] get-tuple-element(arg_tuple.2), index=0 + constant.3 = f32[] constant(1) + constant.4 = f32[] constant(2) + add.4 = f32[] add(get-tuple-element.2, constant.3) + add.5 = f32[] add(get-tuple-element.2, constant.4) + add.6 = f32[] add(add.4, add.5) + ROOT tuple.4 = (f32[]) tuple(add.6) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + tuple.1 = (f32[]) parameter(1) + tuple.2 = (f32[]) parameter(2) + conditional = (f32[]) + conditional(pred.1, tuple.1, tuple.2), true_computation=on_true, + false_computation=on_false + get-first-index = f32[] get-tuple-element(conditional), index=0 + ROOT result = (f32[]) tuple(get-first-index) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass; + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + const HloInstruction* conditional = + FindInstruction(module.get(), "conditional"); + const HloComputation* on_true = conditional->branch_computation(0); + ASSERT_EQ(on_true->instruction_count(), 7); + const HloComputation* on_false = conditional->branch_computation(1); + ASSERT_EQ(on_false->instruction_count(), 7); + + // add.3 in on_true will be moved out, add.1 and add.2 will be in condtional + // root. 
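+  // The conditional now returns a (f32[], f32[]) tuple holding add.1 and
+  // add.2, and add.3 is rebuilt outside from two get-tuple-elements.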
+ ASSERT_TRUE(ShapeUtil::Compatible( + conditional->shape(), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}))); +} + +TEST_F(ConditionalCodeMotionTest, ConditionalIsRootInstruction) { + absl::string_view hlo_string = + R"( +HloModule RemoveIdenticalInstruction + +on_true { + arg_tuple.1 = (f32[]) parameter(0) + get-tuple-element.1 = f32[] get-tuple-element(arg_tuple.1), index=0 + constant.1 = f32[] constant(1) + constant.2 = f32[] constant(2) + constant.3 = f32[] constant(3) + constant.4 = f32[] constant(4) + constant.5 = f32[] constant(5) + add.1 = f32[] add(get-tuple-element.1, constant.1) + add.2 = f32[] add(add.1, constant.2) + add.3 = f32[] add(add.1, constant.3) + add.4 = f32[] add(add.3, constant.5) + multiply.1 = f32[] multiply(add.2, constant.4) + ROOT tuple.6 = (f32[], f32[]) tuple(multiply.1, add.4) +} + +on_false { + arg_tuple.2 = (f32[]) parameter(0) + get-tuple-element.2 = f32[] get-tuple-element(arg_tuple.2), index=0 + constant.6 = f32[] constant(1) + constant.7 = f32[] constant(2) + constant.8 = f32[] constant(3) + constant.9 = f32[] constant(4) + constant.10 = f32[] constant(5) + add.4 = f32[] add(get-tuple-element.2, constant.6) + sub.1 = f32[] subtract(add.4, constant.7) + add.5 = f32[] add(add.4, constant.8) + add.6 = f32[] add(add.5, constant.10) + multiply.2 = f32[] multiply(sub.1, constant.9) + ROOT tuple.6 = (f32[], f32[]) tuple(multiply.2, add.6) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + tuple.1 = (f32[]) parameter(1) + tuple.2 = (f32[]) parameter(2) + ROOT conditional = (f32[], f32[]) + conditional(pred.1, tuple.1, tuple.2), true_computation=on_true, + false_computation=on_false +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass; + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + + const HloInstruction* conditional = + FindInstruction(module.get(), "conditional"); + const HloComputation* on_true = conditional->branch_computation(0); + ASSERT_EQ(on_true->instruction_count(), 9); + const HloComputation* on_false = conditional->branch_computation(1); + ASSERT_EQ(on_false->instruction_count(), 9); + + // Check only one add and multiply is moved out. + // add.3 and add.5 can't be moved out because they share operands with + // other instructions. 
+ HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::Tuple( + op::Multiply(op::GetTupleElement(op::Conditional()), op::Constant()), + op::Add(op::GetTupleElement(op::Conditional()), op::Constant())))); +} + +TEST_F(ConditionalCodeMotionTest, LayoutMisMatchCannotMovedOut) { + absl::string_view hlo_string = + R"( +HloModule LayoutMisMatchCannotMovedOut + +%add.64 (x.139: bf16[], y.139: bf16[]) -> bf16[] { + %x.139 = bf16[]{:T(512)} parameter(0) + %y.139 = bf16[]{:T(512)} parameter(1) + ROOT %add.44073 = bf16[]{:T(512)} add(bf16[]{:T(512)} %x.139, bf16[]{:T(512)} %y.139) +} + +%add.181 (x.256: bf16[], y.256: bf16[]) -> bf16[] { + %x.256 = bf16[]{:T(512)} parameter(0) + %y.256 = bf16[]{:T(512)} parameter(1) + ROOT %add.44842 = bf16[]{:T(512)} add(bf16[]{:T(512)} %x.256, bf16[]{:T(512)} %y.256) +} + +on_true { + %arg_tuple.1 = (bf16[93184,4]{1,0}) parameter(0) + %get-tuple-element.1 = bf16[93184,4]{1,0} get-tuple-element(%arg_tuple.1), index=0 + %all-reduce.1 = bf16[93184,4]{1,0} + all-reduce(bf16[93184,4]{1,0} %get-tuple-element.1), + channel_id=188, replica_groups={{0,1}}, use_global_device_ids=true, + to_apply=%add.64 + %convert.2894 = f32[93184,4]{1,0} convert(bf16[93184, 4]{1,0} %all-reduce.1) + ROOT %tuple.1 = (f32[93184,4]{1,0}) tuple(%convert.2894) +} + +on_false { + %arg_tuple.2 = (bf16[93184,4]{1,0}) parameter(0) + %get-tuple-element.3 = bf16[93184,4]{1,0} get-tuple-element(%arg_tuple.2), index=0 + %copy.1 = bf16[93184,4]{0,1} copy(bf16[93184,4]{1,0} %get-tuple-element.3) + %all-reduce.2 = bf16[93184,4]{0, 1} + all-reduce(bf16[93184,4]{0, 1} %copy.1), + channel_id=188, replica_groups={{0,1}}, use_global_device_ids=true, + to_apply=%add.181 + %convert.3604 = f32[93184,4]{0,1} convert(bf16[93184,4]{0,1} %all-reduce.2) + ROOT %tuple.2 = (f32[93184,4]{0,1}) tuple(%convert.3604) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + arg_tuple.11 = (bf16[93184,4]{1,0}) parameter(1) + arg_tuple.22 = (bf16[93184,4]{1,0}) parameter(2) + conditional = (f32[93184,4]{1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false + get-first-index = f32[93184,4]{1,0} get-tuple-element(conditional), index=0 + ROOT result = (f32[93184,4]{1,0}) tuple(get-first-index) +} +)"; + + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass; + ASSERT_FALSE(pass.Run(&*module).ValueOrDie()); +} + +TEST_F(ConditionalCodeMotionTest, MoveCrossModuleAllReduceOut) { + absl::string_view hlo_string = + R"( +HloModule RemoveIdenticalInstruction + +%add.64 (x.139: bf16[], y.139: bf16[]) -> bf16[] { + %x.139 = bf16[]{:T(512)} parameter(0) + %y.139 = bf16[]{:T(512)} parameter(1) + ROOT %add.44073 = bf16[]{:T(512)} add(bf16[]{:T(512)} %x.139, bf16[]{:T(512)} %y.139) +} + +%add.181 (x.256: bf16[], y.256: bf16[]) -> bf16[] { + %x.256 = bf16[]{:T(512)} parameter(0) + %y.256 = bf16[]{:T(512)} parameter(1) + ROOT %add.44842 = bf16[]{:T(512)} add(bf16[]{:T(512)} %x.256, bf16[]{:T(512)} %y.256) +} + +on_true { + arg_tuple.1 = (bf16[2,54,168,128], bf16[2,52,168,128]) parameter(0) + get-tuple-element.11 = bf16[2,54,168,128] get-tuple-element(arg_tuple.1), index=0 + get-tuple-element.12 = bf16[2,52,168,128] get-tuple-element(arg_tuple.1), index=1 + convolution.1 = bf16[3,3,128,128] convolution(bf16[2,54,168,128] + get-tuple-element.11, bf16[2,52,168,128] + get-tuple-element.12), window={size=52x168 pad=0_0x1_1}, + dim_labels=f01b_i01o->01bf + all-reduce.1 = bf16[3,3,128,128] + 
all-reduce(bf16[3,3,128,128] %convolution.1), + channel_id=188, replica_groups={{0,1}}, use_global_device_ids=true, + to_apply=%add.64, metadata={op_type="Conv2DBackpropFilter" + op_name="gradients/resnet50/conv2d_22/Conv2D_grad/Conv2DBackpropFilter"} + convert.1 = f32[3,3,128,128] convert(bf16[3,3,128,128] %all-reduce.1), + metadata={op_type="Cast" op_name="Cast_15"} + ROOT tuple.1 = (f32[3,3,128,128]) tuple(convert.1) +} + +on_false { + arg_tuple.2 = (bf16[2,86,104,128], bf16[2,84,104,128]) parameter(0) + get-tuple-element.21 = bf16[2,86,104,128] + get-tuple-element(arg_tuple.2), index=0 + get-tuple-element.22 = bf16[2,84,104,128] + get-tuple-element(arg_tuple.2), index=1 + convolution.2 = bf16[3,3,128,128] + convolution(bf16[2,86,104,128] get-tuple-element.21, bf16[2,84,104,128] + get-tuple-element.22), window={size=84x104 pad=0_0x1_1}, + dim_labels=f01b_i01o->01bf + all-reduce.2 = bf16[3,3,128,128] + all-reduce(bf16[3,3,128,128] %convolution.2), + channel_id=485, replica_groups={{0,1}}, use_global_device_ids=true, + to_apply=%add.181, metadata={op_type="Conv2DBackpropFilter" + op_name="gradients/resnet50/conv2d_22/Conv2D_grad/Conv2DBackpropFilter"} + convert.2 = f32[3,3,128,128] + convert(bf16[3,3,128,128] %all-reduce.2), + metadata={op_type="Cast" op_name="Cast_15"} + ROOT tuple.2 = (f32[3,3,128,128]) tuple(convert.2) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + arg_tuple.3 = (bf16[2,54,168,128], bf16[2,52,168,128]) parameter(1) + arg_tuple.4 = (bf16[2,86,104,128], bf16[2,84,104,128]) parameter(2) + conditional = (f32[3,3,128,128]) + conditional(pred.1, arg_tuple.3, arg_tuple.4), true_computation=on_true, + false_computation=on_false + get-first-index = f32[3,3,128,128] + get-tuple-element(conditional), index=0 + ROOT result = (f32[3,3,128,128]) tuple(get-first-index) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass; + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + const HloInstruction* conditional = + FindInstruction(module.get(), "conditional"); + const HloComputation* on_true = conditional->branch_computation(0); + ASSERT_EQ(on_true->instruction_count(), 5); + const HloComputation* on_false = conditional->branch_computation(1); + ASSERT_EQ(on_false->instruction_count(), 5); + + // Checks if conditional shape has changed. 
+ ASSERT_TRUE(ShapeUtil::Compatible( + conditional->shape(), ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape( + BF16, {3, 3, 128, 128})}))); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Tuple(op::Convert(op::AllReduce( + op::GetTupleElement(op::Conditional())))))); +} + +} // namespace + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index e8e1f044704..2f432cd9356 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -35,6 +35,7 @@ filegroup( srcs = [ "runtime_fp16.cc", "runtime_key_value_sort.cc", + "runtime_pow.cc", "runtime_single_threaded_conv2d.cc", "runtime_single_threaded_fft.cc", "runtime_single_threaded_matmul.cc", @@ -49,6 +50,7 @@ filegroup( "runtime_fft_impl.h", "runtime_fp16.h", "runtime_key_value_sort.h", + "runtime_pow.h", "runtime_single_threaded_conv2d.h", "runtime_single_threaded_fft.h", "runtime_single_threaded_matmul.h", @@ -144,6 +146,7 @@ cc_library( "//tensorflow/compiler/xla/service:conditional_simplifier", "//tensorflow/compiler/xla/service:convolution_group_converter", "//tensorflow/compiler/xla/service:dot_decomposer", + "//tensorflow/compiler/xla/service:dynamic_padder", "//tensorflow/compiler/xla/service:dynamic_index_splitter", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", @@ -204,6 +207,7 @@ cc_library( ":cpu_runtime", ":orc_jit_memory_mapper", ":runtime_fp16", + ":runtime_pow", ":runtime_conv2d", ":runtime_conv2d_mkl", ":runtime_fft", @@ -250,6 +254,21 @@ cc_library( ], ) +cc_library( + name = "runtime_pow", + srcs = [ + "runtime_pow.cc", + ], + hdrs = [ + "runtime_pow.h", + ], + copts = runtime_copts(), + deps = [ + "//tensorflow/core/platform:macros", + "//tensorflow/core/platform:types", + ], +) + cc_library( name = "cpu_executable", srcs = ["cpu_executable.cc"], diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index 5e536d362d9..a21ace0d8b2 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -198,11 +198,6 @@ void CompilerFunctor::AddTargetInfoPasses( target_library_info_impl->addVectorizableFunctions( VectorFunctionsForTargetLibraryInfoImpl()); - // TODO(b/136651482): Disable pow(f) so LLVM doesn't transform it into powi. - // It would be better to provide our own powi. - target_library_info_impl->setUnavailable(llvm::LibFunc_pow); - target_library_info_impl->setUnavailable(llvm::LibFunc_powf); - passes->add( new llvm::TargetLibraryInfoWrapperPass(*target_library_info_impl)); passes->add(createTargetTransformInfoWrapperPass( diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index b04237138e8..fe769bbdd2a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -72,6 +72,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" +#include "tensorflow/compiler/xla/service/dynamic_padder.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -239,7 +240,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( HloPassPipeline pipeline("HLO passes through layout assignment"); pipeline.AddInvariantChecker(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); - // Expand random number generation. pipeline.AddPass(); pipeline.AddPass(RandomAlgorithm::RNG_PHILOX); @@ -273,6 +273,13 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass( cost_model, /*convert_batch_groups_only=*/false); + pipeline.AddPass(); + pipeline.AddPass( + /*rewrite_training_op=*/true, + /*rewrite_inference_op=*/true, + /*rewrite_grad_op=*/true); + pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(target_machine_features); { auto& pass = @@ -281,12 +288,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( /*allow_mixed_precision=*/false); pass.AddPass(); - pass.AddPass(); - pass.AddPass( - /*rewrite_training_op=*/true, - /*rewrite_inference_op=*/true, - /*rewrite_grad_op=*/true); - pipeline.AddPass(); AlgebraicSimplifierOptions options; options.set_enable_dot_strength_reduction(false); pass.AddPass(options); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 8c1ae0179c0..f031daecb1f 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -363,7 +363,12 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream( if (shape.IsOpaque()) { return sizeof(void*); } - return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); + if (shape.is_static() || shape.IsTuple()) { + return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); + } + // Each dynamic dimension size is represented as a S32. 
+ int64 metadata_size = sizeof(int32) * shape.dimensions_size(); + return ShapeUtil::ByteSizeOf(shape, sizeof(void*)) + metadata_size; } const InstructionValueSet& CpuExecutable::GetRootValueSet() const { diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index e21ca01c803..05364a4492b 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -109,24 +109,6 @@ llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator) { switch (hlo->opcode()) { - case HloOpcode::kMap: - return [this, hlo, &operand_to_generator]( - const IrArray::Index& index) -> StatusOr { - std::vector operands; - for (int i = 0; i < hlo->operand_count(); i++) { - TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, - operand_to_generator.at(hlo->operand(i))(index)); - operands.push_back(operand_value); - } - return ir_emitter_->EmitElementalMap(*Cast(hlo), - operands, llvm_ir::IrName(hlo)); - }; - case HloOpcode::kReduceWindow: - return [this, hlo, &operand_to_generator](const IrArray::Index& index) { - return ir_emitter_->EmitElementalReduceWindow( - Cast(hlo), - operand_to_generator.at(hlo->operand(0)), index); - }; case HloOpcode::kConvolution: return [this, hlo, &operand_to_generator](const IrArray::Index& index) { return ir_emitter_->EmitElementalConvolution( @@ -134,22 +116,6 @@ llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( operand_to_generator.at(hlo->operand(0)), operand_to_generator.at(hlo->operand(1)), index); }; - case HloOpcode::kReduce: - return [this, hlo, &operand_to_generator](const IrArray::Index& index) { - auto reduce_instr = Cast(hlo); - std::vector input_generators; - for (const HloInstruction* instr : reduce_instr->inputs()) { - input_generators.push_back(operand_to_generator.at(instr)); - } - - std::vector initial_value_generators; - for (const HloInstruction* instr : reduce_instr->init_values()) { - initial_value_generators.push_back(operand_to_generator.at(instr)); - } - return ir_emitter_->EmitElementalReduce( - reduce_instr, std::move(input_generators), - std::move(initial_value_generators), index); - }; default: return ElementalIrEmitter::MakeElementGenerator(hlo, operand_to_generator); diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h index e3fba9306b7..5c9f6677ab3 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h @@ -44,6 +44,12 @@ class CpuElementalIrEmitter : public ElementalIrEmitter { StatusOr EmitTanh(PrimitiveType prim_type, llvm::Value* value) override; + StatusOr> EmitThreadLocalCall( + const HloComputation& callee, absl::Span parameters, + absl::string_view name) override { + return ir_emitter_->EmitThreadLocalCall(callee, parameters, name); + } + IrEmitter* ir_emitter_; }; diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index c19fa779b60..5a4c6250293 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include #include + #include #include #include @@ -570,25 +571,9 @@ Status IrEmitter::HandleSort(HloInstruction* hlo) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort)); Shape keys_shape = sort->keys()->shape(); PrimitiveType keys_type = keys_shape.element_type(); - switch (keys_type) { - case PRED: - case S8: - case U8: - case S16: - case U16: - case BF16: - case F16: - case S32: - case U32: - case F32: - case S64: - case U64: - case F64: - break; - default: - return Unimplemented( - "Element type %s not supported in the Sort op on CPU.", - PrimitiveType_Name(keys_type)); + if (!primitive_util::IsArrayType(keys_type)) { + return Unimplemented("Element type %s not supported in the Sort op on CPU.", + PrimitiveType_Name(keys_type)); } std::vector destination_addresses(sort->operand_count()); for (int64 i = 0; i < sort->operand_count(); ++i) { @@ -695,101 +680,6 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) { return Status::OK(); } -llvm::Value* IrEmitter::EmitElementalMap( - const HloMapInstruction& map_instr, - absl::Span elemental_operands, absl::string_view name) { - return EmitScalarReturningThreadLocalCall(*map_instr.to_apply(), - elemental_operands, name); -} - -StatusOr IrEmitter::EmitElementalReduceWindow( - const HloReduceWindowInstruction* reduce_window, - const llvm_ir::ElementGenerator& input_generator, - const llvm_ir::IrArray::Index& index) { - const HloInstruction* operand = reduce_window->operand(0); - const Window& window = reduce_window->window(); - - // We fold inputs into the accumulator and initialize it to - // the initial value on the reduce_window. - PrimitiveType operand_element_type = operand->shape().element_type(); - llvm::Value* accumulator_address = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), - "reduce_window_accumulator_address", &b_, - MinimumAlignmentForPrimitiveType(operand_element_type)); - Store(Load(GetEmittedValueFor(reduce_window->operand(1))), - accumulator_address); - - llvm_ir::ForLoopNest loops(IrName(reduce_window, "inner"), &b_); - std::vector window_size; - for (const auto& dim : window.dimensions()) { - window_size.push_back(dim.size()); - } - const llvm_ir::IrArray::Index window_index = loops.AddLoopsForShape( - ShapeUtil::MakeShape(operand_element_type, window_size), "window"); - CHECK_EQ(window_index.size(), index.size()); - - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); - - std::vector input_multi_index(index.size()); - llvm::Value* in_bounds_condition = nullptr; - for (size_t i = 0; i < index.size(); ++i) { - llvm::Value* strided_index = - NSWMul(index[i], b_.getInt64(window.dimensions(i).stride())); - input_multi_index[i] = NSWSub( - NSWAdd(strided_index, - NSWMul(window_index[i], - b_.getInt64(window.dimensions(i).window_dilation()))), - b_.getInt64(window.dimensions(i).padding_low())); - - // We need to verify that we are not in the dilated base area. - llvm::Value* dilation_condition = - ICmpEQ(SRem(input_multi_index[i], - b_.getInt64(window.dimensions(i).base_dilation())), - b_.getInt64(0)); - if (in_bounds_condition == nullptr) { - in_bounds_condition = dilation_condition; - } else { - in_bounds_condition = And(in_bounds_condition, dilation_condition); - } - - // Apply base dilation to the index. - input_multi_index[i] = - SDiv(input_multi_index[i], - b_.getInt64(window.dimensions(i).base_dilation())); - - // We need to check if 0 <= input_multi_index[i] < bound, as otherwise we - // are in the padding so that we can skip the computation. 
That is - // equivalent to input_multi_index[i] < bound as an *unsigned* comparison, - // since a negative value will wrap to a large positive value. - llvm::Value* index_condition = - ICmpULT(input_multi_index[i], - b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); - if (in_bounds_condition == nullptr) { - in_bounds_condition = index_condition; - } else { - in_bounds_condition = And(in_bounds_condition, index_condition); - } - } - CHECK(in_bounds_condition != nullptr); - - llvm_ir::LlvmIfData if_data = - llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_); - SetToFirstInsertPoint(if_data.true_block, &b_); - - // We are not in the padding, so carry out the computation. - llvm_ir::IrArray::Index input_index(input_multi_index, operand->shape(), - b_.getInt64Ty()); - TF_ASSIGN_OR_RETURN(llvm::Value* const input_value, - input_generator(input_index)); - llvm::Value* result = EmitScalarReturningThreadLocalCall( - *reduce_window->to_apply(), {Load(accumulator_address), input_value}, - "reducer_function"); - Store(result, accumulator_address); - - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); - return Load(accumulator_address); -} - Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { // Pseudo code for reduce window: // @@ -2099,108 +1989,6 @@ StatusOr IrEmitter::EmitVectorizedReduce( return true; } -StatusOr IrEmitter::EmitElementalReduce( - const HloReduceInstruction* reduce, - std::vector input_generators, - std::vector initial_value_generators, - const llvm_ir::IrArray::Index& index) { - const Shape& out_shape = reduce->shape(); - bool is_variadic = !out_shape.IsArray(); - int accumulators_count = 1; - if (is_variadic) { - CHECK(out_shape.IsTuple()); - accumulators_count = out_shape.tuple_shapes_size(); - } - - absl::Span reduced_dimensions(reduce->dimensions()); - - std::vector accumulator_addrs; - std::vector accumulator_types; - for (int i = 0; i < accumulators_count; i++) { - const Shape& element_shape = - is_variadic ? out_shape.tuple_shapes(i) : out_shape; - PrimitiveType accumulator_type = element_shape.element_type(); - llvm::Type* accumulator_llvm_type = - llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_); - accumulator_types.push_back(accumulator_llvm_type); - - // Initialize an accumulator with init_value. - llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry( - accumulator_llvm_type, "accumulator_" + std::to_string(i), &b_, - MinimumAlignmentForPrimitiveType(accumulator_type)); - TF_ASSIGN_OR_RETURN( - llvm::Value* const init_value, - initial_value_generators[i](llvm_ir::IrArray::Index(index.GetType()))); - Store(init_value, accumulator_addr); - accumulator_addrs.push_back(accumulator_addr); - } - - // The enclosing loops go over all the target elements. Now we have to compute - // the actual target element. For this, we build a new loop nest to iterate - // over all the reduction dimensions in the argument. - // AddLoopsForShapeOnDimensions will return an Index where induction Value*s - // are placed for each dimension in dimensions, and all the rest are nullptrs. - llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_); - const HloInstruction* arg = reduce->operand(0); - std::vector input_multi_index = - loops.AddLoopsForShapeOnDimensions(arg->shape(), reduced_dimensions, - "reduction_dim"); - - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); - - // Build a full index for the input argument, using input_multi_index as the - // base. 
In input_multi_index only the reduction dimensions are filled in. We - // fill in the rest of the dimensions with induction Value*s taken from - // 'index' which iterates over the target array. See the high-level - // description in the XLA documentation for details. - llvm_ir::IrArray::Index::const_iterator it = index.begin(); - - for (auto& i : input_multi_index) { - if (i == nullptr) { - i = *it++; - } - } - CHECK(index.end() == it); - llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(), - b_.getInt64Ty()); - - std::vector reduction_operands; - for (llvm::Value* accum : accumulator_addrs) { - llvm::Value* accum_value = Load(accum); - reduction_operands.push_back(accum_value); - } - - for (int i = 0; i < accumulators_count; i++) { - TF_ASSIGN_OR_RETURN(llvm::Value* const input_element, - input_generators[i](input_index)); - reduction_operands.push_back(input_element); - } - - std::vector results = EmitThreadLocalCall( - *reduce->to_apply(), reduction_operands, "reduce_function"); - - CHECK(results.size() == accumulators_count); - for (int i = 0; i < accumulators_count; i++) { - Store(results[i], accumulator_addrs[i]); - } - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); - - if (is_variadic) { - // Emit a structure, as that what the LoopEmitter expects. - llvm::Value* returned_structure = llvm::UndefValue::get( - llvm::StructType::get(b_.getContext(), accumulator_types)); - for (int i = 0; i < accumulators_count; i++) { - llvm::Value* accumulator_value = Load(accumulator_addrs[i]); - returned_structure = - b_.CreateInsertValue(returned_structure, accumulator_value, i); - } - return returned_structure; - } else { - CHECK_EQ(accumulator_addrs.size(), 1); - return Load(accumulator_addrs[0]); - } -} - Status IrEmitter::HandleReduce(HloInstruction* reduce) { auto arg = reduce->mutable_operand(0); auto init_value = reduce->mutable_operand(1); @@ -2554,7 +2342,95 @@ Status IrEmitter::HandleCall(HloInstruction* call) { return Status::OK(); } +Status IrEmitter::HandleSliceToDynamic(HloInstruction* hlo) { + // TODO(jackcao): Generalize this to generic llvm emitter. + TF_RET_CHECK(hlo->shape().rank() == 1); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo)); + for (int64 i = 1; i < hlo->operand_count(); ++i) { + const int64 dim_index = i - 1; + llvm::Value* source_buffer = GetEmittedValueFor(hlo->operand(i)); + llvm::LoadInst* dim_size = b_.CreateLoad(source_buffer, "dim_size"); + llvm::Value* dest_buffer = GetEmittedValueFor(hlo); + llvm::Value* raw_buffer = + b_.CreateBitCast(dest_buffer, b_.getInt8Ty()->getPointerTo()); + + int32 raw_data_size = + ShapeUtil::ByteSizeOf(ShapeUtil::MakeStaticShape(hlo->shape())); + llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32( + b_.getInt8Ty(), raw_buffer, raw_data_size + dim_index * sizeof(int32)); + b_.CreateStore(dim_size, + b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo())); + } + + return EmitTargetElementLoop(hlo, + [=](const llvm_ir::IrArray::Index& dest_index) { + // TODO(jackcao): Properly linearize dest_index + // and delinearize to source index. + return GetIrArrayFor(hlo->operand(0)) + .EmitReadArrayElement(dest_index, &b_); + }); +} + +Status IrEmitter::HandlePadToStatic(HloInstruction* hlo) { + // TODO(jackcao): Generalize this to generic llvm emitter. 
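For orientation, this handler and the SliceToDynamic handler above implement two sides of the same shape contract: PadToStatic maps a dynamic array to a tuple of (static data, one s32 size per dimension), and SliceToDynamic maps static data plus sizes back to a dynamic array. A hedged sketch of building the PadToStatic result shape with the same ShapeUtil helpers used elsewhere in this change (the wrapper function name is invented):

```cpp
// Hedged illustration (invented wrapper) of the custom-call shape contract:
//   PadToStatic:    f32[<=N] -> (f32[N], s32, ...)   one s32 per dimension
//   SliceToDynamic: (f32[N], s32, ...) -> f32[<=N]
#include <vector>

#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

Shape PadToStaticResultShape(const Shape& dynamic_input) {
  std::vector<Shape> elements = {ShapeUtil::MakeStaticShape(dynamic_input)};
  for (int64 i = 0; i < dynamic_input.rank(); ++i) {
    // Runtime size of dimension i, read from the metadata section.
    elements.push_back(ShapeUtil::MakeScalarShape(S32));
  }
  return ShapeUtil::MakeTupleShape(elements);
}

}  // namespace xla
```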
+ TF_RET_CHECK(hlo->operand(0)->shape().rank() == 1); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo)); + + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice, + assignment_.GetUniqueSlice(hlo, {0})); + const Shape& data_shape = ShapeUtil::GetSubshape(hlo->shape(), {0}); + llvm::Value* data_address = EmitBufferPointer(data_slice, data_shape); + llvm_ir::IrArray data_array(data_address, data_shape); + TF_RETURN_IF_ERROR(llvm_ir::LoopEmitter( + [=](const llvm_ir::IrArray::Index& dest_index) { + // TODO(jackcao): Properly linearize dest_index and + // delinearize to source index. + return GetIrArrayFor(hlo->operand(0)) + .EmitReadArrayElement(dest_index, &b_); + }, + llvm_ir::IrArray(data_address, data_shape), &b_) + .EmitLoop(IrName(hlo))); + std::vector tuple_operand_ptrs; + tuple_operand_ptrs.push_back(data_array.GetBasePointer()); + + // PadToStatic has a dynamic tensor as input and variadic size of outputs: + // (static_tensor, dynamic_dim_0, dynamic_dim_1, ... ) + // Dynamic dimension sizes starts from output index 1. + for (int64 i = 1; i < hlo->shape().tuple_shapes_size(); ++i) { + // Read from the metadata section of the dynamic input (operand 0). + const Shape& dim_shape = ShapeUtil::GetSubshape(hlo->shape(), {i}); + TF_RET_CHECK(Shape::Equal()(dim_shape, ShapeUtil::MakeScalarShape(S32))); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dim_size_slice, + assignment_.GetUniqueSlice(hlo, {i})); + llvm::Value* dest_dim_size_address = + EmitBufferPointer(dim_size_slice, data_shape); + const int64 dim_index = i - 1; + llvm::Value* source_buffer = GetEmittedValueFor(hlo->operand(0)); + llvm::Value* raw_buffer = + b_.CreateBitCast(source_buffer, b_.getInt8Ty()->getPointerTo()); + int32 raw_data_size = ShapeUtil::ByteSizeOf( + ShapeUtil::MakeStaticShape(hlo->operand(0)->shape())); + llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32( + b_.getInt8Ty(), raw_buffer, raw_data_size + dim_index * sizeof(int32)); + llvm::Value* dim_size = b_.CreateLoad( + b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo())); + b_.CreateStore(dim_size, b_.CreateBitCast(dest_dim_size_address, + b_.getInt32Ty()->getPointerTo())); + tuple_operand_ptrs.push_back(dest_dim_size_address); + } + + // Emit static tensor and dynamic sizes as one tuple. + llvm_ir::EmitTuple(GetIrArrayFor(hlo), tuple_operand_ptrs, &b_); + return Status::OK(); +} + Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { + if (custom_call->custom_call_target() == "PadToStatic") { + return HandlePadToStatic(custom_call); + } + if (custom_call->custom_call_target() == "SliceToDynamic") { + return HandleSliceToDynamic(custom_call); + } absl::Span operands(custom_call->operands()); llvm::Type* i8_ptr_type = b_.getInt8PtrTy(); llvm::AllocaInst* operands_alloca = diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index cc5aa3f37fc..9b0d11e9f3f 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -58,6 +58,8 @@ namespace cpu { // functions. class IrEmitter : public DfsHloVisitorWithDefault, public IrBuilderMixin { + friend class CpuElementalIrEmitter; + public: using GeneratorForOperandIrArrays = std::function()>; @@ -113,28 +115,12 @@ class IrEmitter : public DfsHloVisitorWithDefault, // Emit an LLVM global variable for every constant buffer allocation. Status EmitConstantGlobals(); - // Emit code to map one element according to `map_instr`. 
- llvm::Value* EmitElementalMap( - const HloMapInstruction& map_instr, - absl::Span elemental_operands, - absl::string_view name); - // Emit code to emit the element at `index` for a reduce window instruction. - StatusOr EmitElementalReduceWindow( - const HloReduceWindowInstruction* reduce_window, - const llvm_ir::ElementGenerator& input_generator, - const llvm_ir::IrArray::Index& index); // Emit code to emit the element at `index` for a convolution instruction. StatusOr EmitElementalConvolution( const HloConvolutionInstruction* convolution, const llvm_ir::ElementGenerator& input_generator, const llvm_ir::ElementGenerator& kernel_generator, const llvm_ir::IrArray::Index& index); - // Emit code to emit the element at `index` for a reduce instruction. - StatusOr EmitElementalReduce( - const HloReduceInstruction* reduce, - std::vector input_generators, - std::vector initial_value_generator, - const llvm_ir::IrArray::Index& index); protected: // @@ -197,6 +183,8 @@ class IrEmitter : public DfsHloVisitorWithDefault, } private: + Status HandleSliceToDynamic(HloInstruction* hlo); + Status HandlePadToStatic(HloInstruction* hlo); Status HandleAllReduceSingleReplica(HloInstruction* crs); Status HandleAllReduceMultipleReplica(HloInstruction* crs); diff --git a/tensorflow/compiler/xla/service/cpu/runtime_pow.cc b/tensorflow/compiler/xla/service/cpu/runtime_pow.cc new file mode 100644 index 00000000000..08308b4ce57 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_pow.cc @@ -0,0 +1,39 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_pow.h" + +#include "tensorflow/core/platform/macros.h" + +template +static T Powi(T a, tensorflow::int32 b) { + const bool recip = b < 0; + T r = 1; + while (true) { + if (b & 1) r *= a; + b /= 2; + if (b == 0) break; + a *= a; + } + return recip ? 1 / r : r; +} + +float TF_ATTRIBUTE_WEAK __powisf2(float a, tensorflow::int32 b) { + return Powi(a, b); +} + +double TF_ATTRIBUTE_WEAK __powidf2(double a, tensorflow::int32 b) { + return Powi(a, b); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_pow.h b/tensorflow/compiler/xla/service/cpu/runtime_pow.h new file mode 100644 index 00000000000..53f8094256d --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_pow.h @@ -0,0 +1,27 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_POW_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_POW_H_ + +#include "tensorflow/core/platform/types.h" + +// Raises F32 value a to the power of b. +extern "C" float __powisf2(float a, tensorflow::int32 b); + +// Raises F64 value a to the power of b. +extern "C" double __powidf2(double a, tensorflow::int32 b); + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_POW_H_ diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index 153bd572eba..395eb31c13f 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -39,6 +39,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_pow.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" @@ -56,9 +57,8 @@ llvm::SmallVector DetectMachineAttributes() { llvm::StringMap host_features; if (llvm::sys::getHostCPUFeatures(host_features)) { for (auto& feature : host_features) { - if (feature.second) { - result.push_back(std::string(feature.first())); - } + result.push_back((feature.second ? '+' : '-') + + std::string(feature.first())); } } return result; @@ -271,6 +271,8 @@ bool RegisterKnownJITSymbols() { "Host"); registry->Register("__truncdfhf2", reinterpret_cast(__truncdfhf2), "Host"); + registry->Register("__powisf2", reinterpret_cast(__powisf2), "Host"); + registry->Register("__powidf2", reinterpret_cast(__powidf2), "Host"); #undef REGISTER_CPU_RUNTIME_SYMBOL diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index e4676141f65..caea9d9095a 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -109,10 +109,14 @@ class DfsHloVisitorBase { virtual Status HandleRsqrt(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } + virtual Status HandleCbrt(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } virtual Status HandleConvolution(HloInstructionPtr hlo) = 0; virtual Status HandleFft(HloInstructionPtr fft) = 0; virtual Status HandleTriangularSolve(HloInstructionPtr hlo) = 0; virtual Status HandleCholesky(HloInstructionPtr hlo) = 0; + virtual Status HandleAllGather(HloInstructionPtr hlo) = 0; virtual Status HandleAllReduce(HloInstructionPtr hlo) = 0; virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index baa9240fb56..9cd220245ba 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -98,6 +98,9 @@ class DfsHloVisitorWithDefaultBase Status HandleCholesky(HloInstructionPtr hlo) override { return DefaultAction(hlo); } + Status HandleAllGather(HloInstructionPtr crs) override { + return 
DefaultAction(crs); + } Status HandleAllReduce(HloInstructionPtr crs) override { return DefaultAction(crs); } diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc index be26f9a50cd..e193df6d9bd 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc @@ -1620,6 +1620,24 @@ Status DynamicDimensionInference::ForwardDynamicSize(HloInstruction* inst, return Status::OK(); } +bool DynamicDimensionInference::HasDynamicDimension( + HloInstruction* inst) const { + bool has_dynamic_dim = false; + ShapeUtil::ForEachSubshape( + inst->shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (subshape.IsTuple()) { + return; + } + for (int64 i = 0; i < subshape.dimensions_size(); ++i) { + HloInstruction* operand_dynamic_size = GetDynamicSize(inst, index, i); + if (operand_dynamic_size != nullptr) { + has_dynamic_dim = true; + } + } + }); + return has_dynamic_dim; +} + HloInstruction* DynamicDimensionInference::GetDynamicSize( HloInstruction* inst, const ShapeIndex& index, int64 dim) const { auto iter = dynamic_mapping_.find(DynamicDimension{inst, index, dim}); diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.h b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h index 6e3b9e26feb..417f0289143 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.h +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h @@ -51,6 +51,10 @@ class DynamicDimensionInference { HloInstruction* GetDynamicSize(HloInstruction* inst, const ShapeIndex& index, int64 dim) const; + // Returns if current instruction contains any dynamic dimension. Recursively + // go into tuples. + bool HasDynamicDimension(HloInstruction* inst) const; + // Forward dynamic dimension size at `dim` and its constraint from `inst` to // `new_inst`. Status ForwardDynamicSize(HloInstruction* inst, HloInstruction* new_inst, diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc index 09b15781b32..44fdda0f411 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder.cc +++ b/tensorflow/compiler/xla/service/dynamic_padder.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/errors.h" namespace xla { @@ -943,106 +944,6 @@ Status InsertPadToStaticAfterModuleInputs(HloModule* module) { return Status::OK(); } -// For all dynamic outputs that live out of the computation, add -// slice-to-dynamic operations. 
-Status InsertSliceToDynamicBeforeModuleOutputs( - const DynamicDimensionInference& dynamic_dimension_inference, - HloModule* module) { - auto root = module->entry_computation()->root_instruction(); - absl::flat_hash_set dynamic_outputs; - ShapeUtil::ForEachSubshape( - root->shape(), [&](const Shape& subshape, const ShapeIndex& index) { - if (subshape.IsArray()) { - bool has_dynamic_output = false; - for (int64 dim = 0; dim < subshape.rank(); ++dim) { - if (dynamic_dimension_inference.GetDynamicSize(root, index, dim) != - nullptr) { - CHECK_LE(index.size(), 1) << "XLA doesn't support nested output " - "dimension that has dynamic size"; - has_dynamic_output = true; - } - } - if (has_dynamic_output) { - dynamic_outputs.insert(index); - } - } - }); - if (!dynamic_outputs.empty()) { - if (root->shape().IsTuple()) { - std::vector new_root_operands; - ShapeUtil::ForEachSubshape(root->shape(), [&](const Shape& subshape, - const ShapeIndex& index) { - if (!subshape.IsArray()) { - return; - } - - auto gte = module->entry_computation()->AddInstruction( - HloInstruction::CreateGetTupleElement( - ShapeUtil::MakeShapeWithStaticDimensions(subshape), root, - index[0])); - - if (dynamic_outputs.contains(index)) { - CHECK_EQ(index.size(), 1) - << "XLA only support 1 layer nested output tuple"; - // For dynamic outputs, creates an slice operation. - std::vector slice_operands; - // First operand is the original input. Rest are dimension values. - slice_operands.push_back(gte); - // Keep a dynamic version of the subshape as we are removing the - // dynamic dimension in the original root and gte. - Shape dynamic_subshape = subshape; - for (int64 dim = 0; dim < subshape.rank(); ++dim) { - HloInstruction* dynamic_size = - dynamic_dimension_inference.GetDynamicSize(root, index, dim); - if (dynamic_size != nullptr) { - slice_operands.push_back(dynamic_size); - } else { - auto const_size = HloInstruction::CreateConstant( - LiteralUtil::CreateR0(subshape.dimensions(dim))); - slice_operands.push_back( - module->entry_computation()->AddInstruction( - std::move(const_size))); - } - } - // This is a dynamic output, add slice operation. - auto slice = HloInstruction::CreateCustomCall( - dynamic_subshape, slice_operands, "SliceToDynamic"); - new_root_operands.push_back( - module->entry_computation()->AddInstruction(std::move(slice))); - } else { - new_root_operands.push_back(gte); - } - }); - - auto new_root = module->entry_computation()->AddInstruction( - HloInstruction::CreateTuple(new_root_operands)); - module->entry_computation()->set_root_instruction(new_root); - } else { - std::vector slice_operands; - // First operand is the original input. Rest are dimension values. - slice_operands.push_back(root); - for (int64 dim = 0; dim < root->shape().rank(); ++dim) { - HloInstruction* dynamic_size = - dynamic_dimension_inference.GetDynamicSize(root, {}, dim); - if (dynamic_size != nullptr) { - slice_operands.push_back(dynamic_size); - } else { - auto const_size = HloInstruction::CreateConstant( - LiteralUtil::CreateR0(root->shape().dimensions(dim))); - slice_operands.push_back(module->entry_computation()->AddInstruction( - std::move(const_size))); - } - // This is a dynamic output, add slice operation. 
- auto slice = module->entry_computation()->AddInstruction( - HloInstruction::CreateCustomCall(root->shape(), slice_operands, - "SliceToDynamic", "0-0")); - module->entry_computation()->set_root_instruction(slice); - } - } - } - return Status::OK(); -} - // Remove all dynamic shapes between pad-to-static and slice-to-dynamic. // // After this visitor the entry computation then looks like: @@ -1059,46 +960,217 @@ Status InsertSliceToDynamicBeforeModuleOutputs( // ROOT tuple (dynamic) class DynamicShapeRemovingVisitor : public DfsHloVisitorWithDefault { public: + explicit DynamicShapeRemovingVisitor( + const DynamicPadder::OpSupportsDynamismHandler& + op_supports_dynamism_handler, + const DynamicDimensionInference& dynamic_dimension_inference) + : op_supports_dynamism_handler_(op_supports_dynamism_handler), + dynamic_dimension_inference_(dynamic_dimension_inference) {} + Status DefaultAction(HloInstruction* hlo) override; Status HandleCustomCall(HloInstruction* hlo) override; + Status HandleTuple(HloInstruction* hlo) override; + Status HandleGetTupleElement(HloInstruction* hlo) override; + Status HandleParameter(HloInstruction* hlo) override; - static Status Run(HloComputation* computation) { - DynamicShapeRemovingVisitor visitor; - return computation->Accept(&visitor); + static Status Run(HloComputation* computation, + const DynamicPadder::OpSupportsDynamismHandler& + op_supports_dynamism_handler, + const DynamicDimensionInference& dynamic_shape_inference, + bool require_dynamic_output) { + DynamicShapeRemovingVisitor visitor(op_supports_dynamism_handler, + dynamic_shape_inference); + TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + // If the outputs is required to be dynamic form, insert static to dynamic + // conversion as root. + if (require_dynamic_output) { + HloInstruction* root = computation->root_instruction(); + if (dynamic_shape_inference.HasDynamicDimension(root)) { + HloInstruction* new_root = visitor.ConvertToDynamic(root); + computation->set_root_instruction(new_root); + } + } + return Status::OK(); } + + private: + // If a tensor produced by `inst` is in dynamic form, convert it to static and + // returns the new instruction. + HloInstruction* ConvertToStatic(HloInstruction* inst); + + // If a tensor produced by `inst` is in static form, convert it to dynamic and + // returns the new instruction. + HloInstruction* ConvertToDynamic(HloInstruction* inst); + + const DynamicPadder::OpSupportsDynamismHandler& op_supports_dynamism_handler_; + + const DynamicDimensionInference& dynamic_dimension_inference_; }; +HloInstruction* DynamicShapeRemovingVisitor::ConvertToDynamic( + HloInstruction* inst) { + auto* comp = inst->parent(); + const Shape& shape = inst->shape(); + if (shape.IsTuple()) { + std::vector dynamic_operands; + for (int64 i = 0; i < shape.tuple_shapes_size(); ++i) { + auto operand = inst->mutable_operand(i); + if (dynamic_dimension_inference_.HasDynamicDimension(operand)) { + // Recurse. + dynamic_operands.push_back(ConvertToDynamic(operand)); + } else { + dynamic_operands.push_back(operand); + } + } + return comp->AddInstruction(HloInstruction::CreateTuple(dynamic_operands)); + } else { + // Collect the data input, as well as dimension sizes, and feed them to + // slice to dynamic to create a dynamic tensor. + Shape output_shape = shape; // 0th element. 
+ CHECK(output_shape.is_static()); + std::vector slice_operand; + slice_operand.push_back(inst); + for (int64 i = 0; i < output_shape.dimensions_size(); ++i) { + auto dimension_size = + dynamic_dimension_inference_.GetDynamicSize(inst, {}, i); + if (dimension_size == nullptr) { + dimension_size = comp->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(output_shape.dimensions(i)))); + } else { + output_shape.set_dynamic_dimension(i, true); + } + slice_operand.push_back(dimension_size); + } + return comp->AddInstruction(HloInstruction::CreateCustomCall( + output_shape, slice_operand, "SliceToDynamic")); + } +} + +HloInstruction* DynamicShapeRemovingVisitor::ConvertToStatic( + HloInstruction* inst) { + auto* comp = inst->parent(); + const Shape& shape = inst->shape(); + CHECK(shape.is_dynamic()); + if (shape.IsTuple()) { + std::vector static_operands; + for (int64 i = 0; i < shape.tuple_shapes_size(); ++i) { + auto operand = inst->mutable_operand(i); + if (shape.tuple_shapes(i).is_dynamic()) { + static_operands.push_back(ConvertToStatic(operand)); + } else { + static_operands.push_back(operand); + } + } + return comp->AddInstruction(HloInstruction::CreateTuple(static_operands)); + } else { + // The output shape of pad static is a tuple. The 0th element is the data + // output, which is the same as input shape, but without dynamic dimensions. + // i-th element is the dynamic dimension size for i-1th input dimension. + Shape data_output_shape = shape; // 0th element. + data_output_shape.clear_dynamic_dimensions(); + Shape output_shape = ShapeUtil::MakeTupleShape({data_output_shape}); + for (int64 i = 0; i < shape.rank(); ++i) { + ShapeUtil::AppendShapeToTuple(ShapeUtil::MakeScalarShape(S32), + &output_shape); + } + HloInstruction* pad_to_static = + comp->AddInstruction(HloInstruction::CreateCustomCall( + output_shape, {inst}, "PadToStatic", "")); + HloInstruction* data_output = + comp->AddInstruction(HloInstruction::CreateGetTupleElement( + data_output_shape, pad_to_static, 0)); + return data_output; + } +} + Status DynamicShapeRemovingVisitor::DefaultAction(HloInstruction* hlo) { - // Default rule: If input to an op is static, remove dynamism in output. - bool input_is_dynamic = false; - // Default rule: - for (int64 i = 0; i < hlo->operand_count(); ++i) { - if (!hlo->operand(i)->shape().is_static()) { - input_is_dynamic = true; + const bool input_is_dynamic = absl::c_any_of( + hlo->operands(), + [](const HloInstruction* hlo) { return hlo->shape().is_dynamic(); }); + + // By default, ops don't support dynamic lowering. + OpDynamismSupport op_support = OpDynamismSupport::kNoSupport; + if (op_supports_dynamism_handler_) { + op_support = op_supports_dynamism_handler_(hlo); + } + if (op_support == OpDynamismSupport::kNoSupport) { + for (auto* sub_computation : hlo->called_computations()) { + for (auto* param : sub_computation->parameter_instructions()) { + param->mutable_shape()->clear_dynamic_dimensions(); + } } } - - if (!input_is_dynamic) { + // If the input to an op is static and the op doesn't support + // dynamic output, remove dynamism in output -- dynamic_padder should have + // rewritten it to support static shapes. + if (!input_is_dynamic && op_support == OpDynamismSupport::kNoSupport) { hlo->mutable_shape()->clear_dynamic_dimensions(); + return Status::OK(); } + + // Op doesn't support dynamic tensor: For each operand rewrite dynamic input + // into static input using pad_to_static. 
+ if (input_is_dynamic && op_support == OpDynamismSupport::kNoSupport) { + VLOG(1) << "op doesn't support dynamic tensor: " << hlo->ToString(); + for (int64 i = 0; i < hlo->operand_count(); ++i) { + if (hlo->operand(i)->shape().is_dynamic()) { + auto static_operand = ConvertToStatic(hlo->mutable_operand(i)); + TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(i, static_operand)); + } + } + // This op doesn't support dynamic lowering so the op has to be static. + hlo->mutable_shape()->clear_dynamic_dimensions(); + return Status::OK(); + } + + // If the op requires dynamic tensor and input is static -- construct a + // dynamic tensor from the static tensor to feed it. + if (!input_is_dynamic && op_support == OpDynamismSupport::kRequired) { + VLOG(1) << "op doesn't support static tensor: " << hlo->ToString(); + for (int64 i = 0; i < hlo->operand_count(); ++i) { + auto operand = hlo->mutable_operand(i); + if (dynamic_dimension_inference_.HasDynamicDimension(operand)) { + auto dynamic_operand = ConvertToDynamic(hlo->mutable_operand(i)); + TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(i, dynamic_operand)); + } + } + return Status::OK(); + } + return Status::OK(); } -Status DynamicShapeRemovingVisitor::HandleCustomCall(HloInstruction* hlo) { - if (hlo->custom_call_target() == "SliceToDynamic") { - // Don't remove slice-to-dynamic instruction. - return Status::OK(); +Status DynamicShapeRemovingVisitor::HandleGetTupleElement(HloInstruction* hlo) { + *hlo->mutable_shape() = + hlo->operand(0)->shape().tuple_shapes(hlo->tuple_index()); + return Status::OK(); +} + +Status DynamicShapeRemovingVisitor::HandleTuple(HloInstruction* hlo) { + for (int64 i = 0; i < hlo->operand_count(); ++i) { + *hlo->mutable_shape()->mutable_tuple_shapes(i) = hlo->operand(i)->shape(); } - return DefaultAction(hlo); + return Status::OK(); } Status DynamicShapeRemovingVisitor::HandleParameter(HloInstruction* hlo) { return Status::OK(); } +Status DynamicShapeRemovingVisitor::HandleCustomCall(HloInstruction* hlo) { + if (hlo->custom_call_target() == "SliceToDynamic" || + hlo->custom_call_target() == "PadToStatic") { + // Those ops support are created to handle dynamic tensors so by their + // nature they support dynamic lowering. + return Status::OK(); + } + + return DefaultAction(hlo); +} + } // namespace StatusOr DynamicPadder::Run(HloModule* module) { @@ -1137,11 +1209,20 @@ StatusOr DynamicPadder::Run(HloModule* module) { })); TF_RETURN_IF_ERROR(InsertPadToStaticAfterModuleInputs(module)); - TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference, - DynamicDimensionInference::Run(module)); + TF_ASSIGN_OR_RETURN( + DynamicDimensionInference dynamic_dimension_inference, + DynamicDimensionInference::Run(module, custom_call_handler_)); for (HloComputation* computation : module->computations()) { for (HloInstruction* inst : computation->MakeInstructionPostOrder()) { + OpDynamismSupport has_dynamism_support = OpDynamismSupport::kNoSupport; + if (op_supports_dynamism_handler_ != nullptr) { + has_dynamism_support = op_supports_dynamism_handler_(inst); + } + // This op support dynamic lowering, no padding is required. 
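Summarizing the visitor above: DefaultAction chooses a conversion from whether any operand is dynamic and what the dynamism handler reports for the op. A standalone sketch of that decision, with invented names; OpDynamismSupport is the enum added to dynamic_padder.h in this change:

```cpp
// Standalone summary of DynamicShapeRemovingVisitor::DefaultAction above.
#include "tensorflow/compiler/xla/service/dynamic_padder.h"

namespace xla {

enum class Conversion { kNone, kOperandsToStatic, kOperandsToDynamic };

Conversion DecideConversion(bool input_is_dynamic, OpDynamismSupport support) {
  if (support == OpDynamismSupport::kNoSupport) {
    // Static-only lowering: dynamic operands are routed through PadToStatic
    // and the op's own shape is cleared of dynamic dimensions.
    return input_is_dynamic ? Conversion::kOperandsToStatic : Conversion::kNone;
  }
  if (support == OpDynamismSupport::kRequired && !input_is_dynamic) {
    // Dynamic-only lowering: static operands are wrapped with SliceToDynamic.
    return Conversion::kOperandsToDynamic;
  }
  // kOptional ops, or operands already in the form the op expects: no change.
  return Conversion::kNone;
}

}  // namespace xla
```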
+ if (has_dynamism_support != OpDynamismSupport::kNoSupport) { + continue; + } if (inst->opcode() == HloOpcode::kConcatenate) { TF_ASSIGN_OR_RETURN( changed, RewriteDynamicConcat(inst, &dynamic_dimension_inference)); @@ -1152,6 +1233,11 @@ StatusOr DynamicPadder::Run(HloModule* module) { changed, RewriteDynamicSort(inst, &dynamic_dimension_inference)); continue; } + if (inst->opcode() == HloOpcode::kReshape) { + TF_ASSIGN_OR_RETURN( + changed, RewriteDynamicReshape(inst, &dynamic_dimension_inference)); + continue; + } for (int64 operand_num = 0; operand_num < inst->operand_count(); ++operand_num) { HloInstruction* original_operand = inst->mutable_operand(operand_num); @@ -1160,11 +1246,6 @@ StatusOr DynamicPadder::Run(HloModule* module) { continue; } - if (inst->opcode() == HloOpcode::kReshape) { - TF_ASSIGN_OR_RETURN(changed, RewriteDynamicReshape( - inst, &dynamic_dimension_inference)); - continue; - } for (int64 input_dim = 0; input_dim < operand->shape().rank(); ++input_dim) { HloInstruction* operand_dynamic_size = @@ -1195,37 +1276,28 @@ StatusOr DynamicPadder::Run(HloModule* module) { } } } - if (slice_dynamic_output_) { - TF_RETURN_IF_ERROR(InsertSliceToDynamicBeforeModuleOutputs( - dynamic_dimension_inference, module)); - } - // Remove all dynamic dimensions after entry parameter and root instruction -- - // Dynamic padder will produce an equivalent static shaped graph. - for (HloComputation* computation : module->computations()) { - if (computation == module->entry_computation()) { - TF_RETURN_IF_ERROR(DynamicShapeRemovingVisitor::Run(computation)); - } else { - for (HloInstruction* inst : computation->MakeInstructionPostOrder()) { - bool operand_is_dynamic = false; - for (auto* operand : inst->operands()) { - if (!operand->shape().is_static()) { - operand_is_dynamic = true; - } - } - if (!operand_is_dynamic) { - inst->mutable_shape()->clear_dynamic_dimensions(); - } - } - } + // There are ops that only support dynamic lowering and ops that only support + // static lowering, add dynamic<->static tensor conversion around the boundary + // between those ops, as well as the root instruction. + auto computations = module->MakeComputationPostOrder(); + // Reverse postorder so that if caller doesn't support dynamic tensor (while, + // etc), change their called computation to only take static tensors. + for (auto it = computations.rbegin(); it != computations.rend(); ++it) { + HloComputation* computation = *it; + // if slice_dynamic_output_ is set and this is entry computation, we need + // the output tensor to be in dynamic form. + bool require_dynamic_output = + slice_dynamic_output_ && computation == module->entry_computation(); + TF_RETURN_IF_ERROR(DynamicShapeRemovingVisitor::Run( + computation, op_supports_dynamism_handler_, dynamic_dimension_inference, + /*require_dynamic_output=*/require_dynamic_output)); } HloDCE dce; TF_ASSIGN_OR_RETURN(changed, dce.Run(module)); - VLOG(2) << "Post DynamicPadder HLO:"; XLA_VLOG_LINES(2, module->ToString()); - return changed; } diff --git a/tensorflow/compiler/xla/service/dynamic_padder.h b/tensorflow/compiler/xla/service/dynamic_padder.h index f0f3eed0a26..ca2513eaa5c 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder.h +++ b/tensorflow/compiler/xla/service/dynamic_padder.h @@ -36,12 +36,38 @@ namespace xla { // Dynamic_padder removes dynamic shapes from the entry computation, and inserts // custom calls (with dynamic shapes), which are lowered by specialized // emitters: PadToStatic and SliceToDynamic. 
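The handler types and constructor parameters added below give backends a hook to declare which ops have a native dynamic lowering, so the padder skips rewriting them. A hedged sketch of wiring that up in a pass pipeline; the handler, the custom-call target, and the helper function are all invented, and only the constructor signature comes from this change:

```cpp
// Hedged sketch (invented names) of using the extended DynamicPadder
// constructor added below.
#include "tensorflow/compiler/xla/service/dynamic_padder.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"

namespace xla {

OpDynamismSupport MyBackendDynamismHandler(HloInstruction* hlo) {
  // Hypothetical policy: only the custom call "MyDynamicOp" is lowered
  // dynamically; everything else gets padded to its static bound as before.
  if (hlo->opcode() == HloOpcode::kCustomCall &&
      hlo->custom_call_target() == "MyDynamicOp") {
    return OpDynamismSupport::kRequired;
  }
  return OpDynamismSupport::kNoSupport;
}

void AddDynamicPadderPass(HloPassPipeline* pipeline) {
  pipeline->AddPass<DynamicPadder>(
      /*slice_dynamic_output=*/true,
      /*custom_call_handler=*/nullptr,
      /*op_supports_dynamism_handler=*/MyBackendDynamismHandler);
}

}  // namespace xla
```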
+ +// Each instruction can have one of the three modes in supporting dynamic +// lowering. +enum OpDynamismSupport { + // There is no support for dynamic lowering -- dynamic padder will make sure + // the input to that op has static bound by rewriting the op (e.g, extra space + // in reduce_sum will be padded with 0). + kNoSupport = 0, + // The op can take either dynamic input or static input. + kOptional, + // The op only has a dynamic lowering, dynamic padder will make sure the input + // to this op is in dynamic form. + kRequired, +}; + class DynamicPadder : public HloModulePass { public: + // Returns true if given instruction supports native dynamic lowering. If so, + // dynamic padder will not attempt to pad it. + using OpSupportsDynamismHandler = + std::function; + // If `slice_dynamic_output` is true, insert 'slice_to_dynamic' ops to all // outputs that are inferred to be dynamic. - explicit DynamicPadder(bool slice_dynamic_output = true) - : slice_dynamic_output_(slice_dynamic_output) {} + explicit DynamicPadder( + bool slice_dynamic_output = true, + DynamicDimensionInference::CustomCallInferenceHandler + custom_call_handler = nullptr, + OpSupportsDynamismHandler op_supports_dynamism_handler = nullptr) + : slice_dynamic_output_(slice_dynamic_output), + custom_call_handler_(custom_call_handler), + op_supports_dynamism_handler_(op_supports_dynamism_handler) {} absl::string_view name() const override { return "dynamic_padder"; } @@ -51,6 +77,13 @@ class DynamicPadder : public HloModulePass { // Insert 'slice_to_dynamic' ops to all outputs that are inferred to be // dynamic. bool slice_dynamic_output_; + + // A handler for dynamic dimension inference of custom calls. + DynamicDimensionInference::CustomCallInferenceHandler custom_call_handler_; + + // A handler to indicate if a given hlo instruction support native dynamism + // lowering. + OpSupportsDynamismHandler op_supports_dynamism_handler_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_padder_test.cc b/tensorflow/compiler/xla/service/dynamic_padder_test.cc index 31ae1ab60fd..e4c70317f2b 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder_test.cc +++ b/tensorflow/compiler/xla/service/dynamic_padder_test.cc @@ -44,12 +44,49 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { +OpDynamismSupport OpHasDynamismSupport(HloInstruction* hlo) { + if (hlo->opcode() != HloOpcode::kCustomCall) { + return OpDynamismSupport::kNoSupport; + } + if (hlo->custom_call_target() == "OpWithDynamicLowering") { + return OpDynamismSupport::kRequired; + } + return OpDynamismSupport::kNoSupport; +} + +Status CustomCallDynamicDimensionInference( + HloInstruction* hlo, DynamicDimensionInference* inferencer) { + if (hlo->custom_call_target() == "OpWithDynamicLowering") { + if (hlo->shape().IsTuple()) { + // Use the operand's dynamic size as output dynamic size. + HloInstruction* dynamic_size = + inferencer->GetDynamicSize(hlo->mutable_operand(0), {1}, 0); + inferencer->SetDynamicSize(hlo, {1}, 0, dynamic_size); + } else { + // Use the operand's dynamic size as output dynamic size. 
+ HloInstruction* dynamic_size = + inferencer->GetDynamicSize(hlo->mutable_operand(0), {}, 0); + inferencer->SetDynamicSize(hlo, {}, 0, dynamic_size); + } + } + + return Status::OK(); +} + class DynamicPadderTest : public HloTestBase { protected: DynamicPadderTest() : HloTestBase() { module_ = CreateNewVerifiedModule(); } + std::unique_ptr GetHloModule(const string& hlo_text) { + std::unique_ptr module = + ParseAndReturnVerifiedModule(hlo_text).ValueOrDie(); + return module; + } + StatusOr RunPadder() { - DynamicPadder padder; + DynamicPadder padder(/*slice_dynamic_output=*/true, + CustomCallDynamicDimensionInference, + OpHasDynamismSupport); return padder.Run(module_.get()); } @@ -105,6 +142,120 @@ TEST_F(DynamicPadderTest, ReduceTest) { ExpectPadded(reduce->operand(0)); } +TEST_F(DynamicPadderTest, DynamicLoweringTest) { + const string hlo_text = R"( +HloModule DynamicLowering + +ENTRY main { + param = s32[5] parameter(0) + const = s32[] constant(3) + param_padded = s32[<=5] set-dimension-size(param, const), + dimensions={0} + custom-call.1 = s32[<=5] custom-call(param_padded), + custom_call_target="OpWithDynamicLowering" + custom-call.2 = s32[<=5] custom-call(custom-call.1), + custom_call_target="OpWithDynamicLowering" + // Negate doesn't support dynamic lowering. + ROOT negate = s32[<=5] negate(custom-call.2) +} +)"; + + module_ = GetHloModule(hlo_text); + + TF_ASSERT_OK(RunPadder().status()); + // After rewrite, we should have : + // + // param + // | + // SliceToDynamic + // | + // OpWithDynamicLowering (custom_call_1) + // | + // OpWithDynamicLowering (custom_call_2) + // | + // PadToStatic + // | + // Negate + // | + // SliceToDynamic // Root require dynamic form tensor. + auto custom_call_1 = + module_->entry_computation()->GetInstructionWithName("custom-call.1"); + auto custom_call_2 = + module_->entry_computation()->GetInstructionWithName("custom-call.2"); + // Test that the input to custom call + HloInstruction* slice_to_dynamic = custom_call_1->mutable_operand(0); + ASSERT_THAT(slice_to_dynamic->opcode(), HloOpcode::kCustomCall); + ASSERT_THAT(slice_to_dynamic->custom_call_target(), "SliceToDynamic"); + ASSERT_EQ(custom_call_2->user_count(), 1); + HloInstruction* pad_to_static = custom_call_2->users()[0]; + ASSERT_THAT(pad_to_static->opcode(), HloOpcode::kCustomCall); + ASSERT_THAT(pad_to_static->custom_call_target(), "PadToStatic"); + slice_to_dynamic = module_->entry_computation()->root_instruction(); + ASSERT_THAT(slice_to_dynamic->opcode(), HloOpcode::kCustomCall); + ASSERT_THAT(slice_to_dynamic->custom_call_target(), "SliceToDynamic"); +} + +TEST_F(DynamicPadderTest, DynamicLoweringTestTupleInput) { + const string hlo_text = R"( +HloModule DynamicLowering + +ENTRY main { + param = s32[5] parameter(0) + const = s32[] constant(3) + param_padded = s32[<=5] set-dimension-size(param, const), + dimensions={0} + // Create a tuple with static and dynamic componenet. + tuple_arg = (s32[], s32[<=5]) tuple(const, param_padded) + custom-call.1 = (s32[], s32[<=5]) custom-call(tuple_arg), + custom_call_target="OpWithDynamicLowering" + custom-call.2 = (s32[], s32[<=5]) custom-call(custom-call.1), + custom_call_target="OpWithDynamicLowering" + data = s32[<=5]{0} get-tuple-element(custom-call.2), index=1 + // Negate doesn't support dynamic lowering. 
+ ROOT negate = s32[<=5] negate(data) +} +)"; + + module_ = GetHloModule(hlo_text); + + TF_ASSERT_OK(RunPadder().status()); + // After rewrite, we should have : + // + // param + // | + // SliceToDynamic + // | + // Tuple + // | + // OpWithDynamicLowering (custom_call_1) + // | + // OpWithDynamicLowering (custom_call_2) + // | + // GTE + // | + // PadToStatic + // | + // Negate + // | + // SliceToDynamic // Root require dynamic form tensor. + + auto* root = module_->entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::CustomCall("SliceToDynamic", op::Negate(), op::Constant())); + HloInstruction* negate = root->mutable_operand(0); + EXPECT_THAT( + negate, + op::Negate(op::GetTupleElement(op::CustomCall( + "PadToStatic", op::GetTupleElement(op::CustomCall( + "OpWithDynamicLowering", ::testing::_)))))); + auto custom_call_1 = + module_->entry_computation()->GetInstructionWithName("custom-call.1"); + EXPECT_THAT(custom_call_1, + op::CustomCall( + "OpWithDynamicLowering", + op::Tuple(op::Constant(), op::CustomCall("SliceToDynamic")))); +} + TEST_F(DynamicPadderTest, ConvolutionTest) { auto builder = HloComputation::Builder(TestName()); constexpr int xdim = 3; diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 3eb6dab3129..8cb660de46c 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -461,6 +461,8 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( return EmitSqrt(op->shape().element_type(), operand_value); case HloOpcode::kRsqrt: return EmitRsqrt(op->shape().element_type(), operand_value); + case HloOpcode::kCbrt: + return EmitCbrt(op->shape().element_type(), operand_value); case HloOpcode::kFloor: return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor, {operand_value}, @@ -787,6 +789,9 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( case HloOpcode::kRsqrt: { return EmitComplexRsqrt(op, component_type, operand_value); } + case HloOpcode::kCbrt: { + return EmitComplexCbrt(op, component_type, operand_value); + } case HloOpcode::kNegate: return EmitComposeComplex(op, FNeg(EmitExtractReal(operand_value)), FNeg(EmitExtractImag(operand_value))); @@ -1081,6 +1086,19 @@ StatusOr ElementalIrEmitter::EmitComplexRsqrt( return EmitComposeComplex(op, real_part, imag_part); } +// +// Using EmitComplexPower with c=1.0/3.0 and d=0 +StatusOr ElementalIrEmitter::EmitComplexCbrt( + const HloInstruction* op, PrimitiveType prim_type, + llvm::Value* operand_value) { + auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); + auto third = llvm::ConstantFP::get(type, 1.0 / 3.0); + auto zero = llvm::ConstantFP::get(type, 0); + llvm::Value* a = EmitExtractReal(operand_value); + llvm::Value* b = EmitExtractImag(operand_value); + return EmitComplexPower(op, a, b, third, zero); +} + // (a+bi)^(c+di) = // (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)), // where q = c*atan2(b,a)+0.5d*ln(a*a+b*b) @@ -1392,6 +1410,19 @@ StatusOr ElementalIrEmitter::EmitPow(PrimitiveType prim_type, {lhs->getType()}, b_); } +StatusOr ElementalIrEmitter::EmitCbrt(PrimitiveType prim_type, + llvm::Value* value) { + auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); + auto third = llvm::ConstantFP::get(type, 1.0 / 3.0); + auto abs_value = + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_); + TF_ASSIGN_OR_RETURN(llvm::Value * abs_res, + EmitPow(prim_type, abs_value, third)); + auto signed_res = 
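    // Why the fabs/copysign sequence here (a worked example, not new
    // behavior): for real inputs, pow(x, 1/3) yields NaN for negative x,
    // because 1/3 is a non-integer exponent; pow(-8.0, 1.0/3.0) is NaN.
    // Taking pow(|x|, 1/3) and restoring the sign with copysign recovers the
    // real cube root instead:
    //   x = -8:  pow(8, 1/3) = 2,  copysign(2, -8) = -2.
    // The complex path (EmitComplexCbrt above) instead goes through
    // EmitComplexPower with exponent 1/3 + 0i, which by the (a+bi)^(c+di)
    // formula quoted above returns the principal cube root, e.g. for -8 + 0i:
    //   2 * (cos(pi/3) + i*sin(pi/3)) = 1 + 1.732i,
    // not -2, matching the usual complex cbrt convention.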
llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign, + {abs_res, value}, {type}, b_); + return signed_res; +} + StatusOr ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) { @@ -2181,6 +2212,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSqrt: + case HloOpcode::kCbrt: case HloOpcode::kTanh: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { @@ -2390,6 +2422,43 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( -> StatusOr { return EmitElementalDot(hlo, operand_to_generator, dot_result_index); }; + case HloOpcode::kMap: + return [this, hlo, &operand_to_generator]( + const IrArray::Index& index) -> StatusOr { + std::vector operands; + for (int i = 0; i < hlo->operand_count(); i++) { + TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, + operand_to_generator.at(hlo->operand(i))(index)); + operands.push_back(operand_value); + } + std::vector input_generators; + for (const HloInstruction* instr : hlo->operands()) { + input_generators.push_back(operand_to_generator.at(instr)); + } + return EmitElementalMap(Cast(hlo), operands); + }; + case HloOpcode::kReduceWindow: + return [this, hlo, &operand_to_generator](const IrArray::Index& index) { + return EmitElementalReduceWindow( + Cast(hlo), + operand_to_generator.at(hlo->operand(0)), + operand_to_generator.at(hlo->operand(1)), index); + }; + case HloOpcode::kReduce: + return [this, hlo, &operand_to_generator](const IrArray::Index& index) { + auto reduce_instr = Cast(hlo); + std::vector input_generators; + for (const HloInstruction* instr : reduce_instr->inputs()) { + input_generators.push_back(operand_to_generator.at(instr)); + } + + std::vector initial_value_generators; + for (const HloInstruction* instr : reduce_instr->init_values()) { + initial_value_generators.push_back(operand_to_generator.at(instr)); + } + return EmitElementalReduce(reduce_instr, std::move(input_generators), + std::move(initial_value_generators), index); + }; default: return [hlo](const IrArray::Index& index) { return Unimplemented("Unhandled opcode for elemental IR emission: %s", @@ -2419,4 +2488,215 @@ llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op, return complex; } +StatusOr ElementalIrEmitter::EmitElementalMap( + const HloMapInstruction* map_instr, + absl::Span elemental_operands) { + TF_ASSIGN_OR_RETURN( + std::vector values, + EmitThreadLocalCall(*map_instr->to_apply(), elemental_operands, + llvm_ir::IrName(map_instr))); + CHECK_EQ(values.size(), 1); + return values[0]; +} + +StatusOr ElementalIrEmitter::EmitElementalReduceWindow( + const HloReduceWindowInstruction* reduce_window, + const llvm_ir::ElementGenerator& input_generator, + const llvm_ir::ElementGenerator& initial_value_generator, + const llvm_ir::IrArray::Index& index) { + // Pseudocode: + // for each index I in output + // value = init_value + // for each index W in window + // for each dimension i from 0 to rank - 1 + // (input index I)[i] = O[i] * stride[i] + W[i] - pad_low[i] + // if I in bounds of input + // value = function(value, input[I]) + // output[O] = value + const HloInstruction* operand = reduce_window->operand(0); + const Window& window = reduce_window->window(); + + PrimitiveType operand_element_type = operand->shape().element_type(); + llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), + 
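      // (The accumulator is emitted as a function-entry alloca so that it
      // dominates the window loop nest below; it is written once with the
      // init value and then updated by the reducer in the innermost loop.)
      //
      // A concrete instance of the index mapping in the pseudocode above,
      // with illustrative values only: stride 2, window_dilation 1,
      // padding_low 1, base_dilation 1. Output index O = 4 and window index
      // W = 2 then read input index
      //   I = O * stride + W * window_dilation - padding_low
      //     = 4 * 2 + 2 * 1 - 1 = 9,
      // and any I that falls outside [0, bound) or on a dilated-base gap is
      // skipped by the in_bounds checks below.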
"reduce_window_accum_ptr", b_); + { + TF_ASSIGN_OR_RETURN( + llvm::Value* const init_value, + initial_value_generator(llvm_ir::IrArray::Index(index.GetType()))); + Store(init_value, accum_ptr); + } + + llvm::Type* index_type = index.GetType(); + auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + return index.GetConstantWithIndexType(c); + }; + + llvm_ir::ForLoopNest loops(IrName(reduce_window), b_, index_type); + std::vector window_size; + for (const auto& dim : window.dimensions()) { + window_size.push_back(dim.size()); + } + const IrArray::Index window_index = loops.AddLoopsForShape( + ShapeUtil::MakeShape(operand_element_type, window_size), "window"); + CHECK_EQ(window_index.size(), index.size()); + + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_); + + std::vector input_multi_index(index.size()); + llvm::Value* in_bounds = b_->getInt1(true); + for (size_t i = 0; i < index.size(); ++i) { + llvm::Value* stridden_index = + NSWMul(index[i], index_typed_const(window.dimensions(i).stride())); + input_multi_index[i] = NSWSub( + NSWAdd( + stridden_index, + NSWMul(window_index[i], + index_typed_const(window.dimensions(i).window_dilation()))), + index_typed_const(window.dimensions(i).padding_low())); + + // We need to verify that we are not in the dilated base area. + llvm::Value* dilation_condition = + ICmpEQ(SRem(input_multi_index[i], + index_typed_const(window.dimensions(i).base_dilation())), + index_typed_const(0)); + in_bounds = And(in_bounds, dilation_condition); + + // Apply base dilation to the index. + input_multi_index[i] = + SDiv(input_multi_index[i], + index_typed_const(window.dimensions(i).base_dilation())); + + // We must check whether 0 <= input_multi_index[i] < bound, as + // otherwise we are in the pad and so can skip the computation. This + // comparison is equivalent to the unsigned comparison + // input_multi_index[i] < bound, as a negative value wraps to a large + // positive value. + in_bounds = And(in_bounds, + ICmpULT(input_multi_index[i], + index_typed_const(operand->shape().dimensions(i)))); + } + + llvm_ir::LlvmIfData if_data = + llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_); + SetToFirstInsertPoint(if_data.true_block, b_); + + // We are not in pad, so do the computation. + IrArray::Index input_index(input_multi_index, operand->shape(), index_type); + TF_ASSIGN_OR_RETURN(llvm::Value * input_value, input_generator(input_index)); + TF_ASSIGN_OR_RETURN( + std::vector accum_values, + EmitThreadLocalCall(*reduce_window->to_apply(), + {Load(accum_ptr), input_value}, "reducer_function")); + CHECK_EQ(accum_values.size(), 1); + Store(accum_values[0], accum_ptr); + + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_); + return Load(accum_ptr); +} + +StatusOr ElementalIrEmitter::EmitElementalReduce( + const HloReduceInstruction* reduce, + std::vector input_generators, + std::vector initial_value_generators, + const llvm_ir::IrArray::Index& index) { + const Shape& out_shape = reduce->shape(); + bool is_variadic = !out_shape.IsArray(); + int accumulators_count = 1; + if (is_variadic) { + CHECK(out_shape.IsTuple()); + accumulators_count = out_shape.tuple_shapes_size(); + } + + absl::Span reduced_dimensions(reduce->dimensions()); + + std::vector accumulator_addrs; + std::vector accumulator_types; + llvm::Type* index_type = index.GetType(); + for (int i = 0; i < accumulators_count; i++) { + const Shape& element_shape = + is_variadic ? 
out_shape.tuple_shapes(i) : out_shape; + PrimitiveType accumulator_type = element_shape.element_type(); + llvm::Type* accumulator_llvm_type = + llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_); + accumulator_types.push_back(accumulator_llvm_type); + + // Initialize an accumulator with init_value. + llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry( + accumulator_llvm_type, "accumulator_" + std::to_string(i), b()); + TF_ASSIGN_OR_RETURN( + llvm::Value* const init_value, + initial_value_generators[i](llvm_ir::IrArray::Index(index_type))); + Store(init_value, accumulator_addr); + accumulator_addrs.push_back(accumulator_addr); + } + + // The enclosing loops go over all the target elements. Now we have to compute + // the actual target element. For this, we build a new loop nest to iterate + // over all the reduction dimensions in the argument. + // AddLoopsForShapeOnDimensions will return an Index where induction Value*s + // are placed for each dimension in dimensions, and all the rest are nullptrs. + llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), b(), index_type); + const HloInstruction* arg = reduce->operand(0); + std::vector input_multi_index = + loops.AddLoopsForShapeOnDimensions(arg->shape(), reduced_dimensions, + "reduction_dim"); + + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b()); + + // Build a full index for the input argument, using input_multi_index as the + // base. In input_multi_index only the reduction dimensions are filled in. We + // fill in the rest of the dimensions with induction Value*s taken from + // 'index' which iterates over the target array. See the high-level + // description in the XLA documentation for details. + auto it = index.begin(); + + for (auto& i : input_multi_index) { + if (i == nullptr) { + i = *it++; + } + } + CHECK(index.end() == it); + llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(), + index_type); + + std::vector reduction_operands; + for (llvm::Value* accum : accumulator_addrs) { + llvm::Value* accum_value = Load(accum); + reduction_operands.push_back(accum_value); + } + + for (int i = 0; i < accumulators_count; i++) { + TF_ASSIGN_OR_RETURN(llvm::Value* const input_element, + input_generators[i](input_index)); + reduction_operands.push_back(input_element); + } + + TF_ASSIGN_OR_RETURN( + std::vector results, + EmitThreadLocalCall(*reduce->to_apply(), reduction_operands, + "reduce_function")); + + CHECK(results.size() == accumulators_count); + for (int i = 0; i < accumulators_count; i++) { + Store(results[i], accumulator_addrs[i]); + } + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b()); + + if (is_variadic) { + // Emit a structure, as that what the LoopEmitter expects. + llvm::Value* returned_structure = llvm::UndefValue::get( + llvm::StructType::get(b()->getContext(), accumulator_types)); + for (int i = 0; i < accumulators_count; i++) { + llvm::Value* accumulator_value = Load(accumulator_addrs[i]); + returned_structure = + b()->CreateInsertValue(returned_structure, accumulator_value, i); + } + return returned_structure; + } else { + CHECK_EQ(accumulator_addrs.size(), 1); + return Load(accumulator_addrs[0]); + } +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index 99833a5525f..06a9d7b194c 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -17,12 +17,17 @@ limitations under the License. 
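// A note on the variadic case handled by EmitElementalReduce above (an
// illustrative sketch, not an exhaustive contract): for a two-operand reduce,
// e.g. a max-and-argmax style reduction, out_shape is a tuple of two arrays,
// so accumulators_count == 2 and two entry-block accumulators are allocated
// and seeded with their init values. The reducer is then invoked as
//   reduce_function(acc0, acc1, input0[I], input1[I]) -> (new_acc0, new_acc1)
// i.e. all current accumulator values followed by all input elements, and
// each result is stored back into its accumulator. Because the LoopEmitter
// expects a single llvm::Value*, the final accumulator values are packed into
// an LLVM struct, which is what the is_variadic branch at the end builds.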
#define TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_ #include +#include +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h" #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/statusor.h" @@ -116,6 +121,9 @@ class ElementalIrEmitter : public IrBuilderMixin { virtual StatusOr EmitSqrt(PrimitiveType prim_type, llvm::Value* value); + virtual StatusOr EmitCbrt(PrimitiveType prim_type, + llvm::Value* value); + virtual StatusOr EmitRsqrt(PrimitiveType prim_type, llvm::Value* value); @@ -159,6 +167,10 @@ class ElementalIrEmitter : public IrBuilderMixin { PrimitiveType prim_type, llvm::Value* operand_value); + virtual StatusOr EmitComplexCbrt(const HloInstruction* op, + PrimitiveType prim_type, + llvm::Value* operand_value); + virtual StatusOr EmitComplexRsqrt(const HloInstruction* op, PrimitiveType prim_type, llvm::Value* operand_value); @@ -213,6 +225,26 @@ class ElementalIrEmitter : public IrBuilderMixin { const HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& dot_result_index); + virtual StatusOr> EmitThreadLocalCall( + const HloComputation& callee, absl::Span parameters, + absl::string_view name) = 0; + + StatusOr EmitElementalMap( + const HloMapInstruction* map_instr, + absl::Span elemental_operands); + + StatusOr EmitElementalReduceWindow( + const HloReduceWindowInstruction* reduce_window, + const llvm_ir::ElementGenerator& input_generator, + const llvm_ir::ElementGenerator& initial_value_generator, + const llvm_ir::IrArray::Index& index); + + StatusOr EmitElementalReduce( + const HloReduceInstruction* reduce, + std::vector input_generators, + std::vector initial_value_generators, + const llvm_ir::IrArray::Index& index); + llvm::IRBuilder<>* const b_; llvm::Module* module_; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 61bc41283e1..0f6b2cb72e6 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -684,7 +684,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", - "//tensorflow/core:autotuning_proto_cc", + "//tensorflow/core/protobuf:autotuning_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/util/proto:proto_utils", @@ -720,7 +720,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_pass", - "//tensorflow/core:autotuning_proto_cc", + "//tensorflow/core/protobuf:autotuning_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", @@ -1674,7 +1674,7 @@ tf_proto_library_cc( protodeps = [ "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_proto", - "//tensorflow/core:autotuning_proto", + "//tensorflow/core/protobuf:autotuning_proto", ], ) @@ -1685,8 +1685,8 @@ cc_library( deps = [ ":gpu_autotuning_proto_cc", "//tensorflow/compiler/xla:debug_options_flags", - 
"//tensorflow/core:autotuning_proto_cc", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/protobuf:autotuning_proto_cc", "@com_google_absl//absl/container:flat_hash_map", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index c6df786fb51..1be0b1b4e7b 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -305,168 +305,5 @@ llvm::Value* GpuElementalIrEmitter::EmitThreadId() { return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block); } -llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( - const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) { - switch (hlo->opcode()) { - case HloOpcode::kMap: - return [=, &operand_to_generator]( - const IrArray::Index& index) -> StatusOr { - TF_RET_CHECK(!hlo->operands().empty()) - << "Zero operand map not implemented in GPU backend."; - TF_RET_CHECK(hlo->to_apply()->num_parameters() > 0); - std::vector operand_elements; - for (HloInstruction* operand : hlo->operands()) { - TF_ASSIGN_OR_RETURN(llvm::Value * value, - operand_to_generator.at(operand)(index)); - operand_elements.push_back(value); - } - return compute_nested_(*hlo->to_apply(), operand_elements); - }; - case HloOpcode::kReduceWindow: - // Pseudocode: - // for each index I in output - // value = init_value - // for each index W in window - // for each dimension i from 0 to rank - 1 - // (input index I)[i] = O[i] * stride[i] + W[i] - pad_low[i] - // if I in bounds of input - // value = function(value, input[I]) - // output[O] = value - return [=, &operand_to_generator]( - const IrArray::Index& index) -> StatusOr { - const HloInstruction* operand = hlo->operand(0); - const Window& window = hlo->window(); - - PrimitiveType operand_element_type = operand->shape().element_type(); - llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), - "reduce_window_accum_ptr", b_); - { - TF_ASSIGN_OR_RETURN(llvm::Value * init_value, - operand_to_generator.at(hlo->operand(1))( - IrArray::Index(index.GetType()))); - Store(init_value, accum_ptr); - } - - llvm::Type* index_type = index.GetType(); - auto index_typed_const = [&](uint64 c) -> llvm::Constant* { - return index.GetConstantWithIndexType(c); - }; - - llvm_ir::ForLoopNest loops(IrName(hlo), b_, index_type); - std::vector window_size; - for (const auto& dim : window.dimensions()) { - window_size.push_back(dim.size()); - } - const IrArray::Index window_index = loops.AddLoopsForShape( - ShapeUtil::MakeShape(operand_element_type, window_size), "window"); - CHECK_EQ(window_index.size(), index.size()); - - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_); - - std::vector input_multi_index(index.size()); - llvm::Value* in_bounds = b_->getInt1(true); - for (size_t i = 0; i < index.size(); ++i) { - llvm::Value* stridden_index = NSWMul( - index[i], index_typed_const(window.dimensions(i).stride())); - input_multi_index[i] = NSWSub( - NSWAdd(stridden_index, - NSWMul(window_index[i], - index_typed_const( - window.dimensions(i).window_dilation()))), - index_typed_const(window.dimensions(i).padding_low())); - - // We need to verify that we are not in the dilated base area. 
- llvm::Value* dilation_condition = ICmpEQ( - SRem(input_multi_index[i], - index_typed_const(window.dimensions(i).base_dilation())), - index_typed_const(0)); - in_bounds = And(in_bounds, dilation_condition); - - // Apply base dilation to the index. - input_multi_index[i] = - SDiv(input_multi_index[i], - index_typed_const(window.dimensions(i).base_dilation())); - - // We must check whether 0 <= input_multi_index[i] < bound, as - // otherwise we are in the pad and so can skip the computation. This - // comparison is equivalent to the unsigned comparison - // input_multi_index[i] < bound, as a negative value wraps to a large - // positive value. - in_bounds = - And(in_bounds, - ICmpULT(input_multi_index[i], - index_typed_const(operand->shape().dimensions(i)))); - } - - llvm_ir::LlvmIfData if_data = - llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_); - SetToFirstInsertPoint(if_data.true_block, b_); - - // We are not in pad, so do the computation. - IrArray::Index input_index(input_multi_index, operand->shape(), - index_type); - TF_ASSIGN_OR_RETURN(llvm::Value * input_value, - operand_to_generator.at(operand)(input_index)); - TF_ASSIGN_OR_RETURN( - llvm::Value * accum_value, - compute_nested_(*hlo->to_apply(), {Load(accum_ptr), input_value})); - Store(accum_value, accum_ptr); - - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_); - return Load(accum_ptr); - }; - case HloOpcode::kReduce: - // TODO(b/118332391): This should be supported. - CHECK_EQ(hlo->operand_count(), 2) << "Did not expect variadic reduce"; - return [=, &operand_to_generator]( - const IrArray::Index& output_index) -> StatusOr { - const HloInstruction* operand = hlo->operand(0); - llvm::Value* accum_ptr = - b()->CreateAlloca(llvm_ir::PrimitiveTypeToIrType( - hlo->shape().element_type(), module_)); - llvm::Type* index_type = output_index.GetType(); - TF_ASSIGN_OR_RETURN(llvm::Value * init_value, - operand_to_generator.at(hlo->operand(1))( - IrArray::Index(index_type))); - b()->CreateStore(init_value, accum_ptr); - - llvm_ir::ForLoopNest loops(IrName(hlo), b_, index_type); - std::vector input_multi_index = - loops.AddLoopsForShapeOnDimensions( - operand->shape(), hlo->dimensions(), "reduction_dim"); - if (!ShapeUtil::IsScalar(hlo->shape())) { - // Here only input_multi_index[hlo->dimensions()] are non-null, so we - // must set the rest. 
- size_t j = 0; - for (auto& i : input_multi_index) { - if (i == nullptr) { - i = output_index[j++]; - } - } - CHECK_EQ(output_index.size(), j); - } - llvm_ir::IrArray::Index input_index( - input_multi_index, hlo->operand(0)->shape(), index_type); - - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b()); - TF_ASSIGN_OR_RETURN( - llvm::Value * input_value, - operand_to_generator.at(hlo->operand(0))(input_index)); - TF_ASSIGN_OR_RETURN( - llvm::Value * accum_value, - compute_nested_(*hlo->to_apply(), - {b()->CreateLoad(accum_ptr), input_value})); - b()->CreateStore(accum_value, accum_ptr); - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b()); - return b()->CreateLoad(accum_ptr); - }; - default: - return ElementalIrEmitter::MakeElementGenerator(hlo, - operand_to_generator); - } -} - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index c8a58a21980..3c4e9f7c1e6 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -47,10 +47,6 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { llvm::Module* module, llvm::IRBuilder<>* b, NestedComputer compute_nested); - llvm_ir::ElementGenerator MakeElementGenerator( - const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) override; - protected: StatusOr EmitFloatBinaryOp(const HloInstruction* op, llvm::Value* lhs_value, @@ -92,6 +88,17 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { StatusOr EmitComplexAbs(PrimitiveType prim_type, llvm::Value* value) override; + StatusOr> EmitThreadLocalCall( + const HloComputation& callee, absl::Span parameters, + absl::string_view) override { + // TODO(b/118332391): Supported variadic return values. + auto result = compute_nested_(callee, parameters); + if (!result.ok()) { + return result.status(); + } + return std::vector{result.ValueOrDie()}; + } + llvm::Value* EmitThreadId() override; private: diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index b6c1e671986..5f6dfd7d3a5 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -409,6 +409,16 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( pipeline.AddPass>(); } + // GemmRewriter assumes that all transposes are folded into gemms, but, + // since commit 7d529df, this is not always true at this point. + // Therefore, rerun transpose folding. + pipeline.AddPass( + [](const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return IsMatrixMultiplication(dot) ? candidate_operands + : TransposeFolding::OperandIndices{}; + }, + TransposeFolding::NeverFoldTranspose); // Rewrite GEMMs into custom calls. pipeline.AddPass(); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index ec5f10bd2e8..a78ffc8dd1a 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -2016,7 +2016,9 @@ void IrEmitterUnnested::EmitTile( // True iff all threads always execute all instructions in the tiling // dimension X. 
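  // The extra GetRowContiguous() condition added below is conservative: when
  // the rows of a tile are not contiguous in memory (EmitHlo021Tile further
  // down now builds its KernelMappingScheme with /*is_row_contiguous=*/false),
  // evenly dividing the X dimension is no longer taken as proof that every
  // thread stays in bounds, so the bounds check is kept. This is consistent
  // with the new gpu_copy_alone_test added later in this change, which
  // expects no vectorized ld.global.nc.v2 loads for a transposed copy.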
- bool x_tile_fits = mapping_scheme.GetDimsInElems()[kDimX] % tile_size_x == 0; + bool x_tile_fits = + mapping_scheme.GetDimsInElems()[kDimX] % tile_size_x == 0 && + mapping_scheme.GetRowContiguous(); // The outer loop below is simply doing: // @@ -2731,7 +2733,8 @@ void IrEmitterUnnested::EmitHlo021Tile( /*num_threads_y=*/kNumRows, /*num_threads_x=*/kWarpSize, /*indexing_order=*/kLinearIndexingX, - /*vector_size=*/1); + /*vector_size=*/1, + /*is_row_contiguous=*/false); LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(), mapping_scheme.GetThreadsPerBlock()); llvm::Type* index_type = diff --git a/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h b/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h index 5e15d0767a1..d5c4ecbc795 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h @@ -90,13 +90,14 @@ class KernelMappingScheme { KernelMappingScheme(absl::Span dims_in_elems, absl::Span tile_sizes, int64 num_threads_y, int64 num_threads_x, IndexingOrder indexing_order, - int vector_size) + int vector_size, bool is_row_contiguous = false) : dims_in_elems_{dims_in_elems[0], dims_in_elems[1], dims_in_elems[2]}, tile_sizes_{tile_sizes[0], tile_sizes[1], tile_sizes[2]}, num_threads_x_(num_threads_x), num_threads_y_(num_threads_y), indexing_order_(indexing_order), - vector_size_(vector_size) { + vector_size_(vector_size), + is_row_contiguous_(is_row_contiguous) { CHECK_EQ(tile_sizes[1] % num_threads_y_, 0); CHECK_EQ(tile_sizes[2] % num_threads_x_, 0); VLOG(10) << "dims_in_elems_ = " << absl::StrJoin(dims_in_elems_, ","); @@ -134,6 +135,7 @@ class KernelMappingScheme { IndexingOrder GetIndexingOrder() const { return indexing_order_; } int GetVectorSize() const { return vector_size_; } + bool GetRowContiguous() const { return is_row_contiguous_; } private: // The number of elements in each dimension. @@ -159,6 +161,7 @@ class KernelMappingScheme { // to trigger vectorized loads on GPUs while keeping memory // coalescing. const int vector_size_; + const bool is_row_contiguous_; }; // Information to support the code generation for a tiled reduction kernel. diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc index 2d255d76746..aff9e6f162b 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h" #include // NOLINT (required by TF interfaces) +#include #include #include #include @@ -85,6 +86,11 @@ namespace { using tensorflow::BlockingCounter; +bool IsGlobalNcclConfig() { + static bool global_nccl_config = std::getenv("NCCL_COMM_ID") != nullptr; + return global_nccl_config; +} + // Functions to translate an ncclResult_t/cudaError_t to a Status object. Used // by the macros below. 
Status TranslateStatus(ncclResult_t s, const char* file, int64 line, @@ -285,7 +291,6 @@ class NcclClique { std::vector raw_comms(local_device_ordinals_.size(), nullptr); TF_ASSIGN_OR_RETURN(const absl::optional& nccl_id_string, maybe_nccl_unique_id); - ncclUniqueId nccl_id; if (nccl_id_string) { TF_RETURN_IF_ERROR(StringToNcclUniqueId(*nccl_id_string, &nccl_id)); @@ -416,10 +421,12 @@ RendezvousNcclAllReduce::SubmitParticipantImpl( nccl_unique_id = (*participant.nccl_unique_id_callback)(clique_key); } else { if (participant.rendezvous_key.global_devices.size() != - participant.rendezvous_key.num_local_participants) { + participant.rendezvous_key.num_local_participants && + !IsGlobalNcclConfig()) { nccl_unique_id = InvalidArgument( - "Multihost AllReduce on GPU requires a nccl_unique_id_callback " - "to be provided by the client."); + "If not local devices are taking part of a collective API on " + "GPU, the nccl_unique_id_callback must be provided by the " + "client."); } else { nccl_unique_id = absl::optional(); } @@ -568,6 +575,13 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { std::vector global_participating_replicas, GetParticipatingReplicas(global_device_id, instr->replica_groups(), replica_count_, *params.device_assn)); + if (IsGlobalNcclConfig() && + global_participating_replicas.size() != replica_count_) { + return InvalidArgument( + "Partial replica groups are not allowed when using NCCL_COMM_ID " + "environment configuration."); + } + std::vector global_devices; std::vector> local_devices; local_devices.reserve(global_participating_replicas.size()); diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 0906c71064e..7ff8d40b440 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -402,10 +402,26 @@ std::vector NVPTXCompiler::CompileGpuAsmOrGetCachedResult( "using $PATH.", hlo_module_config); } + CHECK(hlo_module_config.debug_options() + .xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found()) + << "There was an error when trying to compile ptx into sass " + "code. If you want to try falling back to the GPU driver to " + "jit compile ptx, you can use the flag " + "--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found." + " Use at your own risk though, it has known drawbacks like " + "increased memory consumption."; } else { LOG(ERROR) << "Error during compilation of ptx to sass: " - << maybe_cubin.status() - << ". Falling back to the GPU driver."; + << maybe_cubin.status(); + CHECK(hlo_module_config.debug_options() + .xla_gpu_unsafe_fallback_to_driver_on_ptxas_error()) + << "There was an error when trying to compile ptx into sass " + "code. Up until May 14 2020, XLA silently ignored such " + "errors and fell back to the GPU driver. This is likely to " + "trigger subtle runtime issues and is hence discouraged. 
" + "If you want to temporarily restore this behavior use the " + "flag --xla_gpu_unsafe_fallback_to_driver_on_ptxas_error " + "and file a bug in b/components/366096."; } // We're going to use the driver to JIT our PTX->SASS, so warn if diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index e04dba418d9..7a9845d0f49 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -235,6 +235,20 @@ tf_cc_test( ], ) +tf_cc_test( + name = "gpu_copy_alone_test", + srcs = [ + "gpu_copy_alone_test.cc", + ], + tags = tf_cuda_tests_tags() + ["no_rocm"], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "gpu_ftz_test", srcs = ["gpu_ftz_test.cc"], diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_alone_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_alone_test.cc new file mode 100644 index 00000000000..1c475ab4e10 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_alone_test.cc @@ -0,0 +1,61 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" + +namespace xla { +namespace gpu { + +namespace { + +// WARNING: This tests must be alone in its file! Otherwise, the +// error isn't caught. We expect and CUDA_ERROR_ILLEGAL_ADDRESS to be +// thrown with the old buggy code. +class CopyAloneNoOptTest : public GpuCodegenTest { + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); + // The test MultiOutputStore contain a MOF fusion and XLA optimizer pass + // doesn't like this. 
+ debug_options.set_xla_disable_all_hlo_passes(true); + return debug_options; + } +}; + +TEST_F(CopyAloneNoOptTest, CopyTranspose) { + const char* hlo_text = R"( +HloModule mod +ENTRY main { + %param = f32[8,32,32,32,16]{4,3,2,1,0} parameter(0) + ROOT %copy = f32[8,32,32,32,16]{3,2,1,4,0} copy(f32[8,32,32,32,16]{4,3,2,1,0} %param) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, + ParseAndReturnVerifiedModule(hlo_text)); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); + + CompileAndOptionallyVerifyPtx(std::move(optimized_module), + R"( +CHECK-NOT: ld.global.nc.v2 +)"); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 94a4df43cf4..32a9038b15a 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -707,6 +707,10 @@ Status HloCostAnalysis::HandleCholesky(const HloInstruction* hlo) { return Status::OK(); } +Status HloCostAnalysis::HandleAllGather(const HloInstruction* hlo) { + return Status::OK(); +} + Status HloCostAnalysis::HandleAllReduce(const HloInstruction* crs) { // We assume 2 replicas, so that each output element is the sum of two input // elements. diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 915c4dcbe84..9fdb42185fb 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -76,6 +76,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleFft(const HloInstruction* fft) override; Status HandleTriangularSolve(const HloInstruction* hlo) override; Status HandleCholesky(const HloInstruction* hlo) override; + Status HandleAllGather(const HloInstruction* hlo) override; Status HandleAllReduce(const HloInstruction* crs) override; Status HandleAllToAll(const HloInstruction* hlo) override; Status HandleCollectivePermute(const HloInstruction* hlo) override; diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index b8e3f83b515..900b557b4dc 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -47,16 +47,14 @@ StatusOr HloDCE::RunOnComputation( // computation's instruction while simultaneously removing instructions. 
std::vector dead_roots; for (auto* instruction : computation->instructions()) { + auto maybe_collective_op = DynCast(instruction); if (instruction != computation->root_instruction() && instruction->user_count() == 0 && computation->IsSafelyRemovable(instruction) && (!instruction->HasSideEffect() || (remove_cross_partition_collective_ops && - ((instruction->opcode() == HloOpcode::kAllReduce && - !Cast(instruction)->constrain_layout()) || - (instruction->opcode() == HloOpcode::kAllToAll && - !Cast(instruction)->constrain_layout()) || - instruction->opcode() == HloOpcode::kCollectivePermute)))) { + (maybe_collective_op != nullptr && + !maybe_collective_op->constrain_layout())))) { dead_roots.push_back(instruction); } } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index e105ea8ce18..3dc9cc24734 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -700,6 +700,38 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleCbrt(HloInstruction* cbrt) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[cbrt], + ElementWiseUnaryOp(cbrt, [](ElementwiseT elem_operand) -> ElementwiseT { + return std::pow(elem_operand, static_cast(1.0 / 3.0)); + return elem_operand.real() < 0 + ? -std::pow(-elem_operand, + static_cast(1.0 / 3.0)) + : std::pow(elem_operand, + static_cast(1.0 / 3.0)); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleCbrt(HloInstruction* cbrt) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[cbrt], + ElementWiseUnaryOp(cbrt, [](ElementwiseT elem_operand) { + return std::cbrt(elem_operand); + })); + return Status::OK(); + } + + Status HandleCbrt(HloInstruction* cbrt) override { + return HandleCbrt(cbrt); + } + Status HandleRsqrt(HloInstruction* rsqrt) override { TF_ASSIGN_OR_RETURN( parent_->evaluated_[rsqrt], diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 78e4d39d3fe..cd2a61d7eff 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -980,6 +980,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kSlice: case HloOpcode::kSort: case HloOpcode::kSqrt: + case HloOpcode::kCbrt: case HloOpcode::kSubtract: case HloOpcode::kTanh: // De-emphasize scalar-shaped elementwise ops -- they're generally @@ -1056,6 +1057,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kGetDimensionSize: case HloOpcode::kSetDimensionSize: return kGray; + case HloOpcode::kAllGather: case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 5fc42eb5e3c..9e9c8b0913b 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -388,6 +388,24 @@ StatusOr> HloInstruction::CreateFromProto( proto.outfeed_config()); break; } + case HloOpcode::kAllGather: { + absl::optional channel_id; + if (proto.channel_id() > 0) { + channel_id = proto.channel_id(); + } + + 
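      // All-gather concatenates one shard from each participant along the
      // single all_gather_dimension, so the output extent in that dimension
      // is the operand extent times the gather-group size. For example
      // (shapes taken from the parser tests added in this change), gathering
      // an f32[128,32] operand along dimension 1 across 4 participants yields
      // f32[128,128], while with replica_groups={{0,1},{2,3}} each group of 2
      // yields f32[128,64].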
TF_RET_CHECK(proto.dimensions_size() == 1) + << "AllGather cannot have more than 1 all-gather dimensions"; + TF_RET_CHECK(all_operands().size() == 1) + << "AllGather must have a single operand"; + int64 all_gather_dimension = proto.dimensions(0); + instruction = CreateAllGather( + shape, operands(0), all_gather_dimension, + std::vector(proto.replica_groups().begin(), + proto.replica_groups().end()), + proto.constrain_layout(), channel_id, proto.use_global_device_ids()); + break; + } case HloOpcode::kAllReduce: { TF_RET_CHECK(proto.called_computation_ids_size() == 1) << "AllReduce should have 1 called computation but sees " @@ -807,6 +825,7 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state, case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSqrt: + case HloOpcode::kCbrt: case HloOpcode::kTanh: break; default: @@ -928,6 +947,15 @@ HloInstruction::CreateReducePrecision(const Shape& shape, shape, operand, exponent_bits, mantissa_bits); } +/* static */ std::unique_ptr HloInstruction::CreateAllGather( + const Shape& shape, HloInstruction* operand, int64 all_gather_dimension, + const std::vector& replica_groups, bool constrain_layout, + const absl::optional& channel_id, bool use_global_device_ids) { + return absl::make_unique( + shape, operand, all_gather_dimension, replica_groups, constrain_layout, + channel_id, use_global_device_ids); +} + /* static */ std::unique_ptr HloInstruction::CreateAllReduce( const Shape& shape, absl::Span operands, HloComputation* reduce_computation, @@ -1517,6 +1545,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kParameter: case HloOpcode::kGetTupleElement: case HloOpcode::kReducePrecision: + case HloOpcode::kAllGather: case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: @@ -1565,6 +1594,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSqrt: + case HloOpcode::kCbrt: case HloOpcode::kTanh: CHECK_EQ(new_operands.size(), 1); clone = CreateUnary(shape, opcode_, new_operands[0]); @@ -1937,6 +1967,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSqrt: + case HloOpcode::kCbrt: case HloOpcode::kSubtract: case HloOpcode::kTanh: case HloOpcode::kTuple: @@ -1994,6 +2025,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kReducePrecision: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: + case HloOpcode::kAllGather: case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: @@ -2381,6 +2413,7 @@ bool HloInstruction::IsElementwiseImpl( case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSqrt: + case HloOpcode::kCbrt: case HloOpcode::kTanh: CHECK_EQ(1, operand_count()); return true; @@ -2847,6 +2880,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleConvolution(this); case HloOpcode::kFft: return visitor->HandleFft(this); + case HloOpcode::kAllGather: + return visitor->HandleAllGather(this); case HloOpcode::kAllReduce: return visitor->HandleAllReduce(this); case HloOpcode::kAllToAll: @@ -2893,6 +2928,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleSin(this); case HloOpcode::kSqrt: return visitor->HandleSqrt(this); + case HloOpcode::kCbrt: + return visitor->HandleCbrt(this); case HloOpcode::kRsqrt: return visitor->HandleRsqrt(this); case HloOpcode::kReal: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h 
b/tensorflow/compiler/xla/service/hlo_instruction.h index 3547de0f5e3..8be7a034877 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -618,6 +618,16 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, const int exponent_bits, const int mantissa_bits); + // Creates an all-gather op, which concats the operands of all participants + // along all_gather_dimension. The replica_groups, channel_id, and + // use_global_device_ids arguments are identical to those in all-reduce, + // except that the order of the group members determines the concatenation + // order of inputs from different participants. + static std::unique_ptr CreateAllGather( + const Shape& shape, HloInstruction* operand, int64 all_gather_dimension, + const std::vector& replica_groups, bool constrain_layout, + const absl::optional& channel_id, bool use_global_device_ids); + // Creates a cross replica reduction op. // // `reduction_computation`: the reduction function. @@ -1605,6 +1615,9 @@ class HloInstruction { virtual int64 dimensions(int64 index) const { LOG(FATAL) << "Unimplemented method."; } + virtual std::vector* mutable_dimensions() { + LOG(FATAL) << "Unimplemented method."; + } // Delegates to HloConcatenateInstruction::concatenate_dimension. int64 concatenate_dimension() const; diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index eb821d40e78..d5bdd674563 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -556,6 +556,51 @@ bool HloCollectiveInstruction::IdenticalSlowPath( }); } +HloAllGatherInstruction::HloAllGatherInstruction( + const Shape& shape, HloInstruction* operand, int64 all_gather_dimension, + const std::vector& replica_groups, bool constrain_layout, + const absl::optional& channel_id, bool use_global_device_ids) + : HloCollectiveInstruction(HloOpcode::kAllGather, shape, {operand}, + replica_groups, constrain_layout, channel_id), + all_gather_dimension_(all_gather_dimension), + use_global_device_ids_(use_global_device_ids) {} + +std::vector HloAllGatherInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector result = + HloCollectiveInstruction::ExtraAttributesToStringImpl(options); + result.push_back(StrCat("dimensions={", all_gather_dimension_, "}")); + if (use_global_device_ids_) { + result.push_back("use_global_device_ids=true"); + } + return result; +} + +std::unique_ptr +HloAllGatherInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* /*context*/) const { + return absl::make_unique( + shape, new_operands[0], all_gather_dimension(), replica_groups(), + constrain_layout(), channel_id(), use_global_device_ids()); +} + +HloInstructionProto HloAllGatherInstruction::ToProto() const { + HloInstructionProto proto = HloCollectiveInstruction::ToProto(); + proto.add_dimensions(all_gather_dimension_); + return proto; +} + +bool HloAllGatherInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) && + all_gather_dimension_ == casted_other.all_gather_dimension() && + use_global_device_ids() == casted_other.use_global_device_ids(); +} + HloAllReduceInstruction::HloAllReduceInstruction( const Shape& shape, absl::Span 
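// A usage sketch for CreateAllGather above (hypothetical values; only the
// factory and its documented semantics come from this change): with an
// f32[4] operand per participant, all_gather_dimension=0 and a single replica
// group listing participants as {1, 0}, the result shape is f32[8] and, per
// the comment on CreateAllGather, participant 1's data is concatenated ahead
// of participant 0's.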
operands, HloComputation* reduce_computation, diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 3b1916e9486..ae78d365cfa 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -348,6 +348,38 @@ class HloCollectiveInstruction : public HloChannelInstruction { bool constrain_layout_; }; +class HloAllGatherInstruction : public HloCollectiveInstruction { + public: + explicit HloAllGatherInstruction( + const Shape& shape, HloInstruction* operand, int64 all_gather_dimension, + const std::vector& replica_groups, bool constrain_layout, + const absl::optional& channel_id, bool use_global_device_ids); + // Same as HloAllReduceInstruction::use_global_device_ids. + bool use_global_device_ids() const { return use_global_device_ids_; } + + // The dimension on which data from different participants are concatenated. + int64 all_gather_dimension() const { return all_gather_dimension_; } + + protected: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + HloInstructionProto ToProto() const override; + + private: + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; + + int64 all_gather_dimension_; + bool use_global_device_ids_; +}; + class HloAllReduceInstruction : public HloCollectiveInstruction { public: explicit HloAllReduceInstruction( @@ -465,6 +497,7 @@ class HloReverseInstruction : public HloInstruction { // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } + std::vector* mutable_dimensions() override { return &dimensions_; } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -491,6 +524,7 @@ class HloConcatenateInstruction : public HloInstruction { // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } + std::vector* mutable_dimensions() override { return &dimensions_; } // Accessor for the dimension in which a concatenate HLO should occur. int64 concatenate_dimension() const { return dimensions(0); } // Returns a serialized representation of this instruction. @@ -520,6 +554,7 @@ class HloReduceInstruction : public HloInstruction { // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } + std::vector* mutable_dimensions() override { return &dimensions_; } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -560,6 +595,7 @@ class HloSortInstruction : public HloInstruction { // Returns the dimension sizes or numbers associated with this instruction. 
const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } + std::vector* mutable_dimensions() override { return &dimensions_; } // Returns the sort dimension for this instruction int64 sort_dimension() const { return dimensions(0); } // Returns a serialized representation of this instruction. @@ -594,6 +630,7 @@ class HloTransposeInstruction : public HloInstruction { // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } + std::vector* mutable_dimensions() override { return &dimensions_; } // Returns whether this instruction does a rank-2 transposition. bool IsRank2Transpose() const; // Returns a serialized representation of this instruction. @@ -621,6 +658,7 @@ class HloBroadcastInstruction : public HloInstruction { // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } + std::vector* mutable_dimensions() override { return &dimensions_; } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -668,6 +706,7 @@ class HloMapInstruction : public HloInstruction { // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } + std::vector* mutable_dimensions() { return &dimensions_; } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index de65ed99303..9722d5c2b76 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -420,6 +420,8 @@ StatusOr HloModule::CreateModuleConfigFromShape( if (execution_options->num_partitions() > 0) { module_config.set_num_partitions(execution_options->num_partitions()); } + module_config.set_use_spmd_partitioning( + execution_options->use_spmd_partitioning()); if (execution_options->has_device_assignment()) { TF_ASSIGN_OR_RETURN(std::unique_ptr device_assignment, DeviceAssignment::Deserialize( diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index b31a9ae6ca5..833d0fe59d0 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -128,6 +128,11 @@ class HloModuleConfig { } int64 num_partitions() const { return num_partitions_; } + void set_use_spmd_partitioning(bool use_spmd_partitioning) { + use_spmd_partitioning_ = use_spmd_partitioning; + } + bool use_spmd_partitioning() const { return use_spmd_partitioning_; } + // Return a string which unambiguously represents all the fields of this data // structure. Used for generating a cache key for storing the compiled // executable. 
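// How the new use_spmd_partitioning knob is plumbed (a minimal sketch; the
// ExecutionOptions setters are the standard protobuf-generated accessors and
// the helper below is assumed, not part of this change):
//
//   ExecutionOptions execution_options = CreateDefaultExecutionOptions();
//   execution_options.set_num_partitions(8);
//   execution_options.set_use_spmd_partitioning(true);
//
// HloModule::CreateModuleConfigFromShape (changed above) then copies the flag
// into HloModuleConfig, where use_spmd_partitioning() selects SPMD (true)
// rather than MPMD (false) when XLA needs to partition the module.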
@@ -199,6 +204,14 @@ class HloModuleConfig { std::vector>* mutable_dot_config() { return &dot_config_; } + absl::Span>> layout_config() const { + return layout_config_; + } + + std::vector>>* mutable_layout_config() { + return &layout_config_; + } + private: // If you add new members, be sure to update compilation_cache_key. @@ -216,6 +229,10 @@ class HloModuleConfig { // The number of partitions (model parallelism) to compile this binary for. int64 num_partitions_ = 1; + // Whether to use SPMD (true) or MPMD (false) when num_partitions_ > 0 and XLA + // needs to partition the module. + bool use_spmd_partitioning_ = false; + // The target maximum parallelism at which to partition HLOs for parallel // execution on the CPU backend. int64 intra_op_parallelism_threads_ = -1; @@ -232,6 +249,9 @@ class HloModuleConfig { FusionConfigCollection fusion_config_collection_ = FusionConfigCollection::kOff; + // TODO(b/155665133): Consolidate fusion, dot, and layout config into a proto + // similar to backend config. + // Custom fusion configuration, where fusion_config_[c][v] control if node v // in computation c must be fused to all its consumers (true) or not (false). std::vector> fusion_config_; @@ -240,6 +260,10 @@ class HloModuleConfig { // how to convert dot operation v (sorted topologically and by computation) to // convolution. std::vector> dot_config_; + + // Layout configuration, where layout_config_[v][i] controls the layout + // decision i of operation v. + std::vector>> layout_config_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index dfe68d93f30..664fa10a990 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -48,6 +48,7 @@ namespace xla { V(kAdd, "add", 2) \ V(kAddDependency, "add-dependency", 2) \ V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \ + V(kAllGather, "all-gather", 1) \ V(kAllReduce, "all-reduce", kHloOpcodeIsVariadic) \ V(kAllToAll, "all-to-all", kHloOpcodeIsVariadic) \ V(kAtan2, "atan2", 2) \ @@ -138,6 +139,7 @@ namespace xla { V(kSlice, "slice", 1) \ V(kSort, "sort", kHloOpcodeIsVariadic) \ V(kSqrt, "sqrt", 1) \ + V(kCbrt, "cbrt", 1) \ V(kSubtract, "subtract", 2) \ V(kTanh, "tanh", 1) \ V(kTrace, "trace", 1) \ diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 4162c5d62d5..2a90c95850c 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -784,6 +784,7 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSqrt: + case HloOpcode::kCbrt: case HloOpcode::kTanh: { if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -849,6 +850,35 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, HloInstruction::CreateBitcastConvert(shape, operands[0])); break; } + case HloOpcode::kAllGather: { + optional>> tmp_groups; + optional> replica_group_ids; + optional channel_id; + optional> dimensions; + optional constrain_layout; + optional use_global_device_ids; + attrs["replica_groups"] = {/*required=*/false, + AttrTy::kBracedInt64ListList, &tmp_groups}; + attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id}; + attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, + &dimensions}; + attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool, + &constrain_layout}; 
+ attrs["use_global_device_ids"] = {/*required=*/false, AttrTy::kBool, + &use_global_device_ids}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + std::vector replica_groups; + if (tmp_groups) { + replica_groups = CreateReplicaGroups(*tmp_groups); + } + instruction = builder->AddInstruction(HloInstruction::CreateAllGather( + shape, operands[0], dimensions->at(0), replica_groups, + constrain_layout ? *constrain_layout : false, channel_id, + use_global_device_ids ? *use_global_device_ids : false)); + break; + } case HloOpcode::kAllReduce: { optional>> tmp_groups; optional to_apply; diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 7e66b4e648d..e18014a3071 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -1480,6 +1480,43 @@ ENTRY CRS { )" }, +// all-gather +{ +"AllGather", +R"(HloModule AllGather + +ENTRY AllGather { + input = f32[128,32]{0,1} parameter(0) + ROOT ag = f32[128,128]{0,1} all-gather(input), replica_groups={}, dimensions={1} +} + +)" +}, +// all-gather with constrained layout +{ +"AllGatherWithLayout", +R"(HloModule AllGather + +ENTRY AllGather { + input = f32[128,32]{0,1} parameter(0) + ROOT ag = f32[128,128]{0,1} all-gather(input), replica_groups={}, constrain_layout=true, dimensions={1} +} + +)" +}, +// all-gather with subgroups +{ +"AllGatherWithSubgroups", +R"(HloModule AllGatherWithSubgroups + +ENTRY AllGatherWithSubgroups { + input = f32[128,32]{0,1} parameter(0) + ROOT ag = f32[128,64]{0,1} all-gather(input), replica_groups={{0,1},{2,3}}, dimensions={1} +} + +)", +/*replica_count=*/4, +}, // all-to-all { "AllToAll", diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util.cc b/tensorflow/compiler/xla/service/hlo_sharding_util.cc new file mode 100644 index 00000000000..129091ca06f --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_sharding_util.cc @@ -0,0 +1,574 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_sharding_util.h" + +#include + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/xla/array.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace hlo_sharding_util { + +absl::optional SelectDominantDevice( + const std::map& device_map, int64* top_count) { + int64 device = 0; + int64 count = 0; + for (auto& it : device_map) { + if (it.second > count) { + count = it.second; + device = it.first; + } + } + if (top_count != nullptr) { + *top_count = count; + } + return count > 0 ? 
absl::optional(device) : absl::optional(); +} + +Status AssignComputationDevice(HloComputation* computation, int64 device) { + VLOG(4) << "Assigning device " << device << " to " << computation->name() + << " computation"; + for (HloInstruction* instruction : computation->instructions()) { + if (!instruction->has_sharding()) { + VLOG(4) << "Assigning device " << device << " to " << instruction->name(); + instruction->set_device_sharding(device); + } + } + return Status::OK(); +} + +absl::optional GetMostOccurringDevice( + absl::Span instructions) { + std::map device_map; + for (HloInstruction* instruction : instructions) { + if (instruction->has_sharding()) { + for (auto& it : instruction->sharding().UsedDevices(nullptr)) { + // The UsedDevices() API returns a map. + device_map[it.first] += it.second; + } + } + } + return SelectDominantDevice(device_map, nullptr); +} + +StatusOr> GetDominantDevice( + absl::Span computations, double dominant_factor) { + int64 instruction_count = 0; + std::map device_map; + for (HloComputation* computation : computations) { + for (HloInstruction* instruction : computation->instructions()) { + int64 count = 1; + if (instruction->has_sharding()) { + for (auto& it : instruction->sharding().UsedDevices(&count)) { + // The UsedDevices() API returns a map. + device_map[it.first] += it.second; + } + } + instruction_count += count; + } + } + int64 count; + absl::optional device = SelectDominantDevice(device_map, &count); + absl::optional dominant_device; + if (device) { + double factor = + static_cast(count) / static_cast(instruction_count); + if (factor >= dominant_factor) { + dominant_device = device; + } + } + return dominant_device; +} + +HloSharding TransposeSharding(const HloSharding& sharding, + const std::vector& dimensions) { + if (sharding.IsTileMaximal()) { + return sharding; + } + const int64 rank = dimensions.size(); + std::vector tile_assignment_dim(rank); + for (int64 i = 0; i < rank; ++i) { + tile_assignment_dim[i] = sharding.tile_assignment().dim(dimensions[i]); + } + Array tile_assignment = sharding.tile_assignment(); + tile_assignment.Reshape(tile_assignment_dim); + tile_assignment.Each([&](absl::Span indices, int64* value) { + std::vector src_indices(indices.size(), -1); + for (int64 i = 0; i < indices.size(); ++i) { + src_indices[dimensions[i]] = indices[i]; + } + *value = sharding.tile_assignment()(src_indices); + }); + return HloSharding::Tile(tile_assignment); +} + +absl::optional ReshapeSharding(const Shape& source_shape, + const Shape& target_shape, + const HloSharding& sharding) { + if (sharding.IsTileMaximal()) { + return sharding; + } + + // In case of a tiled sharding the reshaped sharding will be a valid if the + // reshape is composed from the following operations: + // * Adding or removing dimensions with size 1. + // * Merging consecutive dimensions where only the most major is sharded. + // * Splitting a dimension to consecutive dimensions. + // * Any reshaping of unsharded dimensions. + // Note that merge and split can happen consecutively on the same dimension, + // e.g., f32[1024,256,1024] to f32[128,2048,1024] can be considered that 1024 + // gets split into 128 and 8, but 8 then gets merged with 256. We use stacks + // to make supporting such cases easy. 
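+  // Worked example (illustrative note, values taken from the comment above):
+  // with tile assignment [2,1,1] (devices {0,1} splitting the major dimension
+  // of f32[1024,256,1024]), the split 1024 -> 128x8 keeps the shard on the
+  // new major dimension, and the remaining merge of 8 with 256 involves only
+  // unsharded dimensions, so the reshaped sharding on f32[128,2048,1024] is
+  // again [2,1,1].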
+ const Shape tile_shape = sharding.TileShape(source_shape); + std::vector target_tile_assignment_dimensions; + std::vector source_dims_stack(source_shape.rank()); + std::vector target_dims_stack(target_shape.rank()); + std::vector sharding_tile_dims_stack(source_shape.rank()); + for (int64 i = 0; i < source_shape.rank(); ++i) { + source_dims_stack[i] = source_shape.dimensions(source_shape.rank() - 1 - i); + sharding_tile_dims_stack[i] = + sharding.tile_assignment().dim(source_shape.rank() - 1 - i); + } + for (int64 i = 0; i < target_shape.rank(); ++i) { + target_dims_stack[i] = target_shape.dimensions(target_shape.rank() - 1 - i); + } + while (!source_dims_stack.empty() || !target_dims_stack.empty()) { + if (target_dims_stack.empty()) { + if (Product(sharding_tile_dims_stack) != 1) { + return absl::nullopt; + } + break; + } + int64 s_size = 1; + int64 t_size = 1; + int64 s_partitions = 1; + if (!source_dims_stack.empty()) { + s_size = source_dims_stack.back(); + source_dims_stack.pop_back(); + s_partitions = sharding_tile_dims_stack.back(); + sharding_tile_dims_stack.pop_back(); + } + t_size = target_dims_stack.back(); + target_dims_stack.pop_back(); + if (s_partitions * Product(sharding_tile_dims_stack) == 1) { + // No more partitions left. + target_tile_assignment_dimensions.push_back(1); + continue; + } + if (s_size == t_size) { + // Same dimension. + target_tile_assignment_dimensions.push_back(s_partitions); + } else if (t_size == 1) { + // Trivial dimension added. + target_tile_assignment_dimensions.push_back(1); + source_dims_stack.push_back(s_size); + sharding_tile_dims_stack.push_back(s_partitions); + } else if (s_size == 1) { + // Trivial dimension removed. + if (s_partitions != 1) { + return absl::nullopt; + } + target_dims_stack.push_back(t_size); + } else if (s_size > t_size) { + // Dimension split. + if (s_size % t_size != 0 || t_size % s_partitions != 0) { + return absl::nullopt; + } + target_tile_assignment_dimensions.push_back(s_partitions); + // We have part of the s_size unprocessed, so put it back to stack. + source_dims_stack.push_back(s_size / t_size); + sharding_tile_dims_stack.push_back(1); + } else { + // Dimension merge. Also merge the source dimension with the next, and + // process it next time. + if (s_size % s_partitions != 0) { + return absl::nullopt; + } + CHECK(!source_dims_stack.empty()); + if (sharding_tile_dims_stack.back() != 1 && s_size != s_partitions) { + // If the next dimension to combine is sharded, we require that the + // current dimension's shard size to be 1. Otherwise, the new shard + // would be non-contiguous. 
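+        // (Illustrative case, not from the original comment: merging f32[4,4]
+        // tiled [1,2] into f32[16] would give each device a strided set of
+        // elements rather than one contiguous block, so we give up here.)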
+ return absl::nullopt; + } + source_dims_stack.back() *= s_size; + sharding_tile_dims_stack.back() *= s_partitions; + target_dims_stack.push_back(t_size); + } + } + Array new_tile_assignment = sharding.tile_assignment(); + new_tile_assignment.Reshape(target_tile_assignment_dimensions); + return HloSharding::Tile(new_tile_assignment); +} + +HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim, + absl::Span dims) { + CHECK(!sharding.IsTuple() && !sharding.IsTileMaximal()); + CHECK_NE(absl::c_find(dims, dim), dims.end()) << "dim is not in dims"; + // We optimize the tile assignment on the single dimension dim in a way to + // minimize communication among devices caused by the reshard: + // +---+---+ +---+---+ +-+-+-+-+ + // | | | | 0 | | | | | | + // | 0 | 1 | +-------+ | | | | | + // | | | reshape on | 1 | reshape on | | | | | + // +---+---+ dim 0 => +-------+ dim 1 => |0|2|1|3| + // | | | | 2 | | | | | | + // | 2 | 3 | +-------+ | | | | | + // | | | | 3 | | | | | | + // +---+---+ +---+---+ +-+-+-+-+ + + std::vector tile_dims(sharding.tile_assignment().num_dimensions(), 1); + // Handle ignore dimensions. + std::vector ignore_sizes; + int64 ignore_size = 1; + for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) { + if (absl::c_find(dims, i) == dims.end()) { + int64 size = sharding.tile_assignment().dim(i); + ignore_sizes.push_back(size); + tile_dims[i] = size; + ignore_size *= size; + } + } + + using Buckets = std::vector>; + Array buckets(ignore_sizes, + Buckets(sharding.tile_assignment().dim(dim))); + sharding.tile_assignment().Each( + [&](absl::Span index, int64 device) { + std::vector ignore_index; + for (int64 i = 0; i < index.size(); ++i) { + if (absl::c_find(dims, i) == dims.end()) { + ignore_index.push_back(index[i]); + } + } + buckets(ignore_index)[index[dim]].push_back(device); + }); + std::vector devices; + buckets.Each([&](absl::Span index, const Buckets& buckets) { + for (auto& bucket : buckets) { + devices.insert(devices.end(), bucket.begin(), bucket.end()); + } + }); + tile_dims[dim] = devices.size() / ignore_size; + Array tile_assignment(tile_dims); + tile_assignment.SetValues(devices); + return HloSharding::Tile(tile_assignment); +} + +bool ContainsTileSharding(const HloModule& module) { + for (const HloComputation* computation : module.computations()) { + for (const HloInstruction* instruction : computation->instructions()) { + if (instruction->has_sharding() && + !instruction->sharding().IsTileMaximal()) { + return true; + } + } + } + return false; +} + +HloSharding GatherOutputSharding(const HloSharding& index_sharding, + const HloInstruction* hlo) { + if (index_sharding.IsTileMaximal()) { + return index_sharding; + } + + const GatherDimensionNumbers& dnums = hlo->gather_dimension_numbers(); + std::vector output_tile_assignment_dims; + for (int64 i = 0, index_dim = 0; i < hlo->shape().rank(); ++i) { + if (absl::c_binary_search(dnums.offset_dims(), i)) { + output_tile_assignment_dims.push_back(1); + } else { + output_tile_assignment_dims.push_back( + index_sharding.tile_assignment().dim(index_dim)); + index_dim++; + } + } + Array new_tile_assignment = index_sharding.tile_assignment(); + new_tile_assignment.Reshape(output_tile_assignment_dims); + return HloSharding::Tile(new_tile_assignment); +} + +HloSharding GatherIndexSharding(const HloSharding& output_sharding, + const HloInstruction* hlo) { + if (output_sharding.IsTileMaximal()) { + return output_sharding; + } + + const GatherDimensionNumbers& dnums = 
hlo->gather_dimension_numbers(); + std::vector index_tile_assignment_dims; + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + if (!absl::c_binary_search(dnums.offset_dims(), i)) { + index_tile_assignment_dims.push_back( + output_sharding.tile_assignment().dim(i)); + } + } + Array new_tile_assignment = output_sharding.tile_assignment(); + new_tile_assignment.Reshape(index_tile_assignment_dims); + return HloSharding::Tile(new_tile_assignment); +} + +HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo) { + if (hlo.sharding().IsTileMaximal()) { + return hlo.sharding(); + } + + const GatherDimensionNumbers& dnums = hlo.gather_dimension_numbers(); + std::vector tile_assignment_dims(hlo.shape().rank()); + int64 num_elements = 1; + for (int64 i = 0; i < hlo.shape().rank(); ++i) { + if (!absl::c_binary_search(dnums.offset_dims(), i)) { + tile_assignment_dims[i] = hlo.sharding().tile_assignment().dim(i); + num_elements *= hlo.sharding().tile_assignment().dim(i); + } else { + tile_assignment_dims[i] = 1; + } + } + if (num_elements == hlo.sharding().tile_assignment().num_elements()) { + // Output sharding is only on non offset dimensions. We use output sharding + // to shard this gather op directly. + return hlo.sharding(); + } + + if (num_elements == 1) { + // Output sharding is only on offset dimensions. We do not shard this gather + // op. Return a tile maximal sharding with the first device in output + // sharding tile assignment. + return HloSharding::AssignDevice(*hlo.sharding().tile_assignment().begin()); + } + + // Output sharding is on both offset and non offset dimensions. We shard the + // gather op only on non offset dimensions. + // For example: + // - the gather op has sharding [2,2]{0,1,2,3}, + // - first dimension is non offset dimension, + // - second dimension is offset dimension, + // Then the result sharding will be [2,1]{0,2}. 
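+  // (Illustrative note: the slice below keeps the full tile range on non
+  // offset dimensions and only index 0 on offset dimensions, so in the
+  // example above the [2,2] assignment {{0,1},{2,3}} is sliced to {{0},{2}}.)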
+ std::vector slice_starts(hlo.shape().rank(), 0LL), + slice_limits(hlo.shape().rank()); + for (int64 i = 0; i < hlo.shape().rank(); ++i) { + if (!absl::c_binary_search(dnums.offset_dims(), i)) { + slice_limits[i] = hlo.sharding().tile_assignment().dim(i); + } else { + slice_limits[i] = 1; + } + } + Array tile_assignment = + hlo.sharding().tile_assignment().Slice(slice_starts, slice_limits); + return HloSharding::Tile(tile_assignment); +} + +HloSharding ScatterIndexSharding(const HloSharding& data_sharding, + const HloInstruction* hlo) { + if (data_sharding.IsTileMaximal()) { + return data_sharding; + } + + const ScatterDimensionNumbers& dnums = hlo->scatter_dimension_numbers(); + std::vector index_tile_assignment_dims; + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + if (!absl::c_binary_search(dnums.update_window_dims(), i)) { + index_tile_assignment_dims.push_back( + data_sharding.tile_assignment().dim(i)); + } + } + if (index_tile_assignment_dims.size() < hlo->operand(1)->shape().rank()) { + index_tile_assignment_dims.push_back(1); + } + Array new_tile_assignment = data_sharding.tile_assignment(); + new_tile_assignment.Reshape(index_tile_assignment_dims); + return HloSharding::Tile(new_tile_assignment); +} + +HloSharding ScatterDataSharding(const HloSharding& index_sharding, + const HloInstruction* hlo) { + if (index_sharding.IsTileMaximal()) { + return index_sharding; + } + + const ScatterDimensionNumbers& dnums = hlo->scatter_dimension_numbers(); + std::vector data_tile_assignment_dims; + for (int64 i = 0, index_dim = 0; i < hlo->shape().rank(); ++i) { + if (absl::c_binary_search(dnums.update_window_dims(), i)) { + data_tile_assignment_dims.push_back(1); + } else { + data_tile_assignment_dims.push_back( + index_sharding.tile_assignment().dim(index_dim)); + index_dim++; + } + } + Array new_tile_assignment = index_sharding.tile_assignment(); + new_tile_assignment.Reshape(data_tile_assignment_dims); + return HloSharding::Tile(new_tile_assignment); +} + +HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding, + const HloInstruction& hlo) { + if (index_sharding.IsTileMaximal()) { + return index_sharding; + } + + // Only shard on first "number of scatter_window_dims" dimensions. + const ScatterDimensionNumbers& dnums = hlo.scatter_dimension_numbers(); + int64 num_elements = 1; + int64 index_dim = 0; + for (int64 i = 0; i < hlo.shape().rank(); ++i) { + if (absl::c_binary_search(dnums.inserted_window_dims(), i)) { + num_elements *= index_sharding.tile_assignment().dim(index_dim); + index_dim++; + } + } + if (num_elements == index_sharding.tile_assignment().num_elements()) { + // Index sharding is only on scatter_window_dims. We use this index sharding + // directly. + return index_sharding; + } + + // Index sharding is only on update_window_dims. We do not shard this scatter + // op. Return a tile maximal sharding with the first device in index sharding + // tile assignment. 
+ if (num_elements == 1) { + return HloSharding::AssignDevice(*index_sharding.tile_assignment().begin()); + } + + const int64 index_rank = hlo.operand(1)->shape().rank(); + std::vector slice_starts(index_rank, 0LL), slice_limits(index_rank); + for (int64 i = 0; i < index_rank; ++i) { + if (i < index_dim) { + slice_limits[i] = index_sharding.tile_assignment().dim(i); + } else { + slice_limits[i] = 1; + } + } + Array tile_assignment = + index_sharding.tile_assignment().Slice(slice_starts, slice_limits); + return HloSharding::Tile(tile_assignment); +} + +HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding, + const HloInstruction& hlo) { + if (data_sharding.IsTileMaximal()) { + return data_sharding; + } + + const ScatterDimensionNumbers& dnums = hlo.scatter_dimension_numbers(); + const int64 data_rank = hlo.operand(2)->shape().rank(); + std::vector tile_assignment_dims(data_rank, 1LL); + int64 num_elements = 1; + for (int64 i = 0; i < hlo.shape().rank(); ++i) { + if (absl::c_binary_search(dnums.inserted_window_dims(), i)) { + CHECK_LT(i, data_rank); + tile_assignment_dims[i] = data_sharding.tile_assignment().dim(i); + num_elements *= data_sharding.tile_assignment().dim(i); + } + } + if (num_elements == data_sharding.tile_assignment().num_elements()) { + // Data sharding is only on scatter_window_dims. We use this data sharding + // directly. + return data_sharding; + } + + if (num_elements == 1) { + // Data sharding is only on update_window_dims. We do not shard this + // scatter op. Return a tile maximal sharding with the first device in + // data sharding tile assignment. + return HloSharding::AssignDevice(*data_sharding.tile_assignment().begin()); + } + + // Data sharding is on both update_window_dims and scatter_window_dims. We + // shard the scatter op only on scatter_window_dims. For example: + // - the scatter data has sharding [2,2]{0,1,2,3}, + // - first dimension is scatter_window_dims, + // - second dimension is update_window_dims, + // Then the result sharding will be [2,1]{0,2}. + std::vector slice_starts(data_rank, 0LL); + Array tile_assignment = + data_sharding.tile_assignment().Slice(slice_starts, tile_assignment_dims); + return HloSharding::Tile(tile_assignment); +} + +StatusOr, HloOpcode>> +IdentityValueAndHloOpcodeForScatterReduceComputation( + const HloScatterInstruction& scatter) { + auto computation = scatter.to_apply(); + // We only handle computations with 2 parameters and only 1 calculation. 
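+  // (Illustrative note: the 2 parameters plus the single root op give an
+  // instruction count of 3, e.g.
+  //   add { p0 = f32[] parameter(0)  p1 = f32[] parameter(1)
+  //         ROOT sum = f32[] add(p0, p1) })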
+ if (computation->instruction_count() != 3) { + return Status( + tensorflow::error::Code::INVALID_ARGUMENT, + "Expected scatter reduce computation with 2 parameters and only 1 " + "calculation"); + } + + auto root_instruction = computation->root_instruction(); + if (root_instruction->opcode() == HloOpcode::kAdd || + root_instruction->opcode() == HloOpcode::kOr) { + return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::Zero( + scatter.shape().element_type())), + root_instruction->opcode()); + } else if (root_instruction->opcode() == HloOpcode::kMultiply || + root_instruction->opcode() == HloOpcode::kAnd) { + return std::make_pair(HloInstruction::CreateConstant( + LiteralUtil::One(scatter.shape().element_type())), + root_instruction->opcode()); + } else if (root_instruction->opcode() == HloOpcode::kMaximum) { + return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::MinValue( + scatter.shape().element_type())), + root_instruction->opcode()); + } else if (root_instruction->opcode() == HloOpcode::kMinimum) { + return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::MaxValue( + scatter.shape().element_type())), + root_instruction->opcode()); + } + + return Status(tensorflow::error::Code::INVALID_ARGUMENT, + "Expected scatter reduce computation which is " + "add/or/multiply/add/min/max"); +} + +std::vector DevicesForSharding( + const HloSharding& sharding, const std::vector& available_devices) { + std::vector devices; + if (sharding.IsReplicated()) { + for (int64 d : available_devices) { + if (!HloSharding::IsReservedDevice(d)) { + devices.push_back(d); + } + } + return devices; + } + + for (int64 i : available_devices) { + if (sharding.UsesDevice(i)) { + devices.push_back(i); + } + } + DCHECK(std::all_of(sharding.tile_assignment().begin(), + sharding.tile_assignment().end(), [&](int64 device) { + return std::find(available_devices.begin(), + available_devices.end(), + device) != available_devices.end(); + })); + return devices; +} + +} // namespace hlo_sharding_util +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util.h b/tensorflow/compiler/xla/service/hlo_sharding_util.h new file mode 100644 index 00000000000..00d9434a34d --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_sharding_util.h @@ -0,0 +1,143 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" + +namespace xla { +namespace hlo_sharding_util { + +// Given a map, selects the device with higher +// occurrence count (if any). If top_count in not nullptr, it will receive the +// count of the dominant device returned. +absl::optional SelectDominantDevice( + const std::map& device_map, int64* top_count); + +// Assigns all the instructions of a computation, to a given device. +// This API does not recurse into called computations, and does not assign +// instructions which already have sharding. +Status AssignComputationDevice(HloComputation* computation, int64 device); + +// Given an instruction container, returns the device which is most commonly +// occurring among the instructions. +absl::optional GetMostOccurringDevice( + absl::Span instructions); + +// Given a set of computations, tries to extract the dominant device. A device +// is dominant if the combined occurrence among all the instructions of the +// input computations, is greater/equal than/to dominant_factor (real number +// from 0 to 1). +// This API does not recurse into called computations. +// If no device exists that satisfies the condition, the returned optional will +// hold no value. +StatusOr> GetDominantDevice( + absl::Span computations, double dominant_factor); + +// Returns the HloSharding with the tile dimensions and tile assignment +// transposed based on the specified dimension numbers. In case of a tile +// maximal sharding returns the original sharding. +HloSharding TransposeSharding(const HloSharding& sharding, + const std::vector& dimensions); + +// Returns the HloSharding with the tile shape reshaped based on the source and +// target shapes and the tile assignment adjusted to correspond to the new tile +// shape or absl::nullopt if the resulting reshape would create an invalid +// sharding (non continuous or non uniformly sized tiles). In case of a tile +// maximal sharding returns the original sharding. +absl::optional ReshapeSharding(const Shape& source_shape, + const Shape& target_shape, + const HloSharding& sharding); + +// Returns a sharding tiled on unique dimension dim by reshaping the tile +// assignment of the sharding argument. Only dimensions in the dims span +// argument are considered for reshaping, the others are ignored. +// Assumptions: sharding is tile sharded, and dim must be included in dims. +HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim, + absl::Span dims); + +// Returns true if the provided module includes one or more instructions with +// a tile sharding. +bool ContainsTileSharding(const HloModule& module); + +// Returns the preferred output sharding for a gather op based on the sharding +// of the indces. +HloSharding GatherOutputSharding(const HloSharding& index_sharding, + const HloInstruction* hlo); + +// Returns the preferred index sharding for a gather op based on the sharding +// of the output. 
+HloSharding GatherIndexSharding(const HloSharding& output_sharding, + const HloInstruction* hlo); + +// Returns a new HloSharding for a gather op so that only non offset dimensions +// are sharded. Assume "result" is returned by this function. It is ensured that +// "GetIndexSharding(result, hlo)" will have the same number of elements as +// "result". +HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo); + +// Returns the preferred index sharding for a scatter op based on the sharding +// of the data. +HloSharding ScatterIndexSharding(const HloSharding& data_sharding, + const HloInstruction* hlo); + +// Returns the preferred data sharding for a scatter op based on the sharding +// of the index. +HloSharding ScatterDataSharding(const HloSharding& index_sharding, + const HloInstruction* hlo); + +// Returns a new index sharding for a scatter op so that we only shard on first +// "number of scatter_window_dims" dimensions. Assume "result" is returned by +// this function. It is ensured that "ScatterDataSharding(result, hlo)" will +// have the same number of elements as "result". +HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding, + const HloInstruction& hlo); + +// Returns a new data sharding for a scatter op so that we only shard on +// scatter_window_dims. Assume "result" is returned by this function. It is +// ensured that "ScatterIndexSharding(result, hlo)" will have the same number of +// elements as "result". +HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding, + const HloInstruction& hlo); + +// Returns an identity value and an HloOpcode for reduce computation of scatter +// instruction. +// - If computation is add/or, return 0/false with corresponding op code; +// - If computation is multiply/and, return 1/true with corresponding op code. +// - If computation is min/max, return max value/min value with corresponding op +// code. +// - Otherwise, return error status. +StatusOr, HloOpcode>> +IdentityValueAndHloOpcodeForScatterReduceComputation( + const HloScatterInstruction& scatter); + +// Given a sharding and a list of devices in the topology, return a +// list of the devices that `sharding` applies to. +std::vector DevicesForSharding( + const HloSharding& sharding, const std::vector& available_devices); + +} // namespace hlo_sharding_util +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_util_test.cc new file mode 100644 index 00000000000..02496c75965 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_sharding_util_test.cc @@ -0,0 +1,206 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_sharding_util.h" + +#include "tensorflow/compiler/xla/test.h" + +namespace xla { +namespace hlo_sharding_util { +namespace { + +TEST(HloShardingUtilTest, TransposeShardingReplicated) { + EXPECT_EQ(TransposeSharding(HloSharding::Replicate(), {0, 1, 2}), + HloSharding::Replicate()); +} + +TEST(HloShardingUtilTest, TransposeShardingTiled) { + HloSharding input = HloSharding::Tile(Array4D({{{{0, 1}}, {{2, 3}}}})); + HloSharding output = + HloSharding::Tile(Array4D({{{{0}, {2}}}, {{{1}, {3}}}})); + EXPECT_EQ(TransposeSharding(input, {3, 0, 1, 2}), output); +} + +TEST(HloShardingUtilTest, ReshapeShardingMaximal) { + Shape input_shape = ShapeUtil::MakeShape(F32, {2, 3, 5}); + Shape output_shape = ShapeUtil::MakeShape(F32, {3, 5, 2}); + HloSharding sharding = HloSharding::AssignDevice(7); + absl::optional result = + ReshapeSharding(input_shape, output_shape, sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), sharding); +} + +TEST(HloShardingUtilTest, ReshapeShardingTiledInvalid) { + Shape input_shape = ShapeUtil::MakeShape(F32, {2, 3, 5}); + Shape output_shape = ShapeUtil::MakeShape(F32, {3, 5, 2}); + HloSharding sharding = HloSharding::Tile(Array3D({{{0}, {1}}})); + absl::optional result = + ReshapeSharding(input_shape, output_shape, sharding); + EXPECT_FALSE(result.has_value()); +} + +TEST(HloShardingUtilTest, ReshapeShardingTiledMerge) { + Shape input_shape = ShapeUtil::MakeShape(F32, {4, 5, 7}); + Shape output_shape = ShapeUtil::MakeShape(F32, {20, 7}); + HloSharding input_sharding = + HloSharding::Tile(Array3D({{{0}}, {{1}}})); + HloSharding output_sharding = HloSharding::Tile(Array2D({{0}, {1}})); + absl::optional result = + ReshapeSharding(input_shape, output_shape, input_sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), output_sharding); +} + +TEST(HloShardingUtilTest, ReshapeShardingTiledSplit) { + Shape input_shape = ShapeUtil::MakeShape(F32, {16, 7}); + Shape output_shape = ShapeUtil::MakeShape(F32, {4, 4, 7}); + HloSharding input_sharding = HloSharding::Tile(Array2D({{0}, {1}})); + HloSharding output_sharding = + HloSharding::Tile(Array3D({{{0}}, {{1}}})); + absl::optional result = + ReshapeSharding(input_shape, output_shape, input_sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), output_sharding); +} + +TEST(HloShardingUtilTest, ReshapeShardingTiledSplitThenMerge) { + Shape input_shape = ShapeUtil::MakeShape(F32, {16, 4, 7}); + Shape output_shape = ShapeUtil::MakeShape(F32, {4, 16, 7}); + HloSharding input_sharding = + HloSharding::Tile(Array3D({{{0}}, {{1}}})); + HloSharding output_sharding = + HloSharding::Tile(Array3D({{{0}}, {{1}}})); + absl::optional result = + ReshapeSharding(input_shape, output_shape, input_sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), output_sharding); +} + +TEST(HloShardingUtilTest, ReshapeShardingTiledArbitraryMinorDimensions) { + Shape input_shape = ShapeUtil::MakeShape(F32, {16, 7, 5, 3}); + Shape output_shape = ShapeUtil::MakeShape(F32, {4, 15, 2, 14}); + Array sharding_array({2, 1, 1, 1}); + sharding_array(0, 0, 0, 0) = 0; + sharding_array(1, 0, 0, 0) = 1; + HloSharding sharding = HloSharding::Tile(sharding_array); + absl::optional result = + ReshapeSharding(input_shape, output_shape, sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), sharding); +} + +TEST(HloShardingUtilTest, 
ReshapeShardingTiledTrivialDimensions) { + Shape input_shape = ShapeUtil::MakeShape(F32, {3, 1, 5, 7}); + Shape output_shape = ShapeUtil::MakeShape(F32, {3, 5, 1, 7}); + HloSharding input_sharding = + HloSharding::Tile(Array4D({{{{0}, {1}}}})); + HloSharding output_sharding = + HloSharding::Tile(Array4D({{{{0}}, {{1}}}})); + absl::optional result = + ReshapeSharding(input_shape, output_shape, input_sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), output_sharding); +} + +TEST(HloShardingUtilTest, ReshapeShardingTrivialDImensionInsertedToEnd) { + Shape input_shape = ShapeUtil::MakeShape(F32, {8, 16}); + Shape output_shape = ShapeUtil::MakeShape(F32, {8, 16, 1}); + HloSharding input_sharding = HloSharding::Tile(Array2D({{0}, {1}})); + HloSharding output_sharding = + HloSharding::Tile(Array3D({{{0}}, {{1}}})); + absl::optional result = + ReshapeSharding(input_shape, output_shape, input_sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), output_sharding); +} + +TEST(HloShardingUtilTest, NoopReshapeShardingEmptyTile) { + Shape shape = ShapeUtil::MakeShape(F32, {7, 1, 1}); + HloSharding sharding = HloSharding::Tile(Array3D({{{0}, {1}}})); + absl::optional result = ReshapeSharding(shape, shape, sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), sharding); +} + +TEST(HloShardingUtilTest, ReshapeShardingScalar) { + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1}); + Shape output_shape = ShapeUtil::MakeShape(F32, {}); + HloSharding sharding = HloSharding::Tile(Array3D({{{0}, {1}}})); + absl::optional result = + ReshapeSharding(input_shape, output_shape, sharding); + EXPECT_FALSE(result.has_value()); +} + +TEST(HloShardingUtilTest, ReshapeToTileDimension2D_Dim0) { + HloSharding sharding = HloSharding::Tile(Array2D({{0, 1}, {2, 3}})); + HloSharding result = + ReshapeToTileDimension(sharding, /*dim=*/0, /*dims=*/{0, 1}); + EXPECT_EQ(result.tile_assignment(), Array2D({{0}, {1}, {2}, {3}})); +} + +TEST(HloShardingUtilTest, ReshapeToTileDimension2D_Dim1) { + HloSharding sharding = HloSharding::Tile(Array2D({{0, 1}, {2, 3}})); + HloSharding result = + ReshapeToTileDimension(sharding, /*dim=*/1, /*dims=*/{0, 1}); + EXPECT_EQ(result.tile_assignment(), Array2D({{0, 2, 1, 3}})); +} + +TEST(HloShardingUtilTest, ReshapeToTileDimension3D_Dim0) { + HloSharding sharding = + HloSharding::Tile(Array3D({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}})); + HloSharding result = + ReshapeToTileDimension(sharding, /*dim=*/0, /*dims=*/{0, 1, 2}); + EXPECT_EQ( + result.tile_assignment(), + Array3D({{{0}}, {{1}}, {{2}}, {{3}}, {{4}}, {{5}}, {{6}}, {{7}}})); +} + +TEST(HloShardingUtilTest, ReshapeToTileDimension3D_Dim1) { + HloSharding sharding = + HloSharding::Tile(Array3D({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}})); + HloSharding result = + ReshapeToTileDimension(sharding, /*dim=*/1, /*dims=*/{0, 1, 2}); + EXPECT_EQ(result.tile_assignment(), + Array3D({{{0}, {1}, {4}, {5}, {2}, {3}, {6}, {7}}})); +} + +TEST(HloShardingUtilTest, ReshapeToTileDimension3D_Dim2) { + HloSharding sharding = + HloSharding::Tile(Array3D({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}})); + HloSharding result = + ReshapeToTileDimension(sharding, /*dim=*/2, /*dims=*/{0, 1, 2}); + EXPECT_EQ(result.tile_assignment(), + Array3D({{{0, 2, 4, 6, 1, 3, 5, 7}}})); +} + +TEST(HloShardingUtilTest, ReshapeToTileDimension2D_Dim2_Batch1) { + // Tile sharding in batch dimension, i.e. + // sharding={devices[2,2,2]0,1,2,3,4,5,6,7,8}. 
+ HloSharding sharding = + HloSharding::Tile(Array3D({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}})); + // Reshape on dimensions {1, 2} only, therefore ignoring batch dimension 0. + HloSharding result = ReshapeToTileDimension(sharding, /*dim=*/2, + /*dims=*/{1, 2}); + // Expected result is {devices=[2,1,4]0,2,1,3,4,6,5,7}, i.e. the two + // non-batch dimensions {{0, 1}, {2, 3}} and {{4, 5}, {6, 7}} are individually + // reshaped to tile dimension 2, i.e. {{0, 2, 1, 3}}, {{4, 6, 5, 7}}. + EXPECT_EQ(result.tile_assignment(), + Array3D({{{0, 2, 1, 3}}, {{4, 6, 5, 7}}})); +} + +} // namespace +} // namespace hlo_sharding_util +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 0911af10f38..d15a36532eb 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -236,6 +236,40 @@ static Status CheckReplicaGroups(HloInstruction* hlo) { return Status::OK(); } +Status ShapeVerifier::HandleAllGather(HloInstruction* hlo) { + auto ag = Cast(hlo); + TF_RETURN_IF_ERROR(CheckReplicaGroups(ag)); + TF_RET_CHECK(ag->all_gather_dimension() >= 0); + TF_RET_CHECK(ag->all_gather_dimension() < ag->shape().rank()); + TF_RET_CHECK(ag->all_gather_dimension() < ag->operand(0)->shape().rank()); + if (ag->use_global_device_ids() && ag->replica_groups().empty()) { + return InternalError( + "Replica group must be specified when use_global_device_ids is true"); + } + + int64 shard_count = CeilOfRatio( + ag->shape().dimensions(ag->all_gather_dimension()), + ag->operand(0)->shape().dimensions(ag->all_gather_dimension())); + if (ag->channel_id().has_value()) { + if (ag->use_global_device_ids()) { + TF_RET_CHECK(shard_count == ag->replica_groups()[0].replica_ids_size()); + } else { + if (ag->replica_groups().empty() || + ag->replica_groups()[0].replica_ids_size() != 1) { + return InternalError( + "Replica group size must be 1 when use_global_device_ids is " + "false if the all-gather is also cross-partition"); + } + } + } else if (!ag->replica_groups().empty()) { + // Cross-replica all-gather: shard count is subgroup size. + TF_RET_CHECK(shard_count == ag->replica_groups()[0].replica_ids_size()); + } + return CheckShape(ag, ShapeInference::InferAllGatherShape( + ag->operand(0)->shape(), ag->all_gather_dimension(), + shard_count)); +} + Status ShapeVerifier::HandleAllReduce(HloInstruction* crs) { TF_RETURN_IF_ERROR(CheckReplicaGroups(crs)); @@ -628,9 +662,11 @@ Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) { shape_size_function_(bitcast->operand(0)->shape())) { return InternalError( "Bitcast cannot have different shape sizes of output (%d) and operand " - "(%d)", + "(%d) (%s) (%s)", shape_size_function_(bitcast->shape()), - shape_size_function_(bitcast->operand(0)->shape())); + shape_size_function_(bitcast->operand(0)->shape()), + bitcast->shape().ToString(true), + bitcast->operand(0)->shape().ToString(true)); } return Status::OK(); } @@ -697,11 +733,7 @@ Status ShapeVerifier::HandleFusion(HloInstruction* fusion) { } for (HloInstruction* fused_param : fused_parameters) { int64 param_no = fused_param->parameter_number(); - // Since fusion buffers aren't materialized, fusion parameters will not have - // the same memory space as the fusion operand. 
- if (!ShapesSame(fused_param->shape(), fusion->operand(param_no)->shape(), - /*minor_to_major_only=*/false, - /*ignore_memory_space=*/true)) { + if (!ShapesSame(fused_param->shape(), fusion->operand(param_no)->shape())) { return InternalError( "Shape mismatch between parameter number %d and its operand in " "%s.", diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 2e83361a591..7a2d3dc2e6c 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -56,6 +56,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleFft(HloInstruction* fft) override; Status HandleCholesky(HloInstruction* hlo) override; Status HandleTriangularSolve(HloInstruction* hlo) override; + Status HandleAllGather(HloInstruction* hlo) override; Status HandleAllReduce(HloInstruction* crs) override; Status HandleAllToAll(HloInstruction* hlo) override; Status HandleCollectivePermute(HloInstruction* hlo) override; diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 53938a489f1..1bc3d24274c 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -145,6 +145,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kCholesky: case HloOpcode::kConditional: case HloOpcode::kConvolution: + case HloOpcode::kAllGather: case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: @@ -175,6 +176,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kSendDone: case HloOpcode::kSort: case HloOpcode::kSqrt: + case HloOpcode::kCbrt: case HloOpcode::kTanh: case HloOpcode::kTrace: case HloOpcode::kTriangularSolve: diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h index 3c35fda55f1..9e4bdeb2b2d 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.h +++ b/tensorflow/compiler/xla/service/interpreter/executor.h @@ -203,7 +203,8 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { std::unique_ptr GetStreamImplementation() override { - return std::unique_ptr(new host::HostStream()); + return std::unique_ptr( + new host::HostStream(/*thread_stack_size=*/0)); } std::unique_ptr GetTimerImplementation() override { diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 64390e77ddb..13699f3adf9 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -951,7 +951,8 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { if (!Shape::Equal() .IgnoreDynamicDimension() .MinorToMajorOnlyInLayout()(instruction_subshape, - buffer->shape())) { + buffer->shape()) && + instruction->opcode() != HloOpcode::kBitcast) { return InternalError( "Layout of instruction %s at index {%s} does not match " "source LogicalBuffer %s: %s vs %s", @@ -1798,13 +1799,6 @@ Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) { // potential bugs in the layout assignment pass that may accidentally use the // existing layout. 
for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kBitcast) { - // bitcasts are inherently layout sensitive and so a bitcast instruction - // present in the IR before layout assignment is a bug. - return InternalError( - "Unexpected bitcast operation seen during layout assignment: %s.", - instruction->ToString()); - } // Some instructions carry mandatory layouts in their shape. if (instruction->opcode() != HloOpcode::kInfeed && !IsLayoutConstrainedCustomCall(instruction) && @@ -2179,6 +2173,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kConditional: case HloOpcode::kConvert: case HloOpcode::kCos: + case HloOpcode::kAllGather: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: case HloOpcode::kDivide: @@ -2220,6 +2215,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kSlice: case HloOpcode::kSort: case HloOpcode::kSqrt: + case HloOpcode::kCbrt: case HloOpcode::kSubtract: case HloOpcode::kTanh: case HloOpcode::kPopulationCount: diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 304a80c7a52..6e575247e6b 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -814,27 +814,6 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { EXPECT_THAT(false_result->opcode(), HloOpcode::kCopy); } -TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) { - auto builder = HloComputation::Builder(TestName()); - auto constant0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( - {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); - builder.AddInstruction( - HloInstruction::CreateBitcast(constant0->shape(), constant0)); - auto m = CreateNewVerifiedModule(); - m->AddEntryComputation(builder.Build()); - - ComputationLayout computation_layout( - m->entry_computation()->ComputeProgramShape()); - LayoutAssignment layout_assignment(&computation_layout); - Status error_status = layout_assignment.Run(m.get()).status(); - EXPECT_FALSE(error_status.ok()); - EXPECT_THAT( - error_status.error_message(), - ::testing::HasSubstr( - "Unexpected bitcast operation seen during layout assignment")); -} - TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { // Pin non matching layouts to parameter and root. 
const char* module_str = R"( diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index ef8ddfc1a76..c80646e0c70 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -112,6 +112,8 @@ ExecutionOptions CreateExecutionOptions( } execution_options.set_num_replicas(build_options.num_replicas()); execution_options.set_num_partitions(build_options.num_partitions()); + execution_options.set_use_spmd_partitioning( + build_options.use_spmd_partitioning()); if (build_options.has_device_assignment()) { TF_CHECK_OK(build_options.device_assignment().Serialize( execution_options.mutable_device_assignment())); diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index d5a118c00dc..742de71e74c 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -585,23 +585,35 @@ void AlternateMemoryBestFitHeap::AppendBufferInfoDebugString( // definition_time: int. Logical time this value was defined in the schedule. // use_times: string. This is a semicolon-separated list of integers for all // the use times. + // use_names: string. This is a semicolon-separated list of string + // representation of uses. if (debug_str->empty()) { // Append the column names. absl::StrAppend(debug_str, - "buffer_id,buffer_name,alt_mem_benefit,size,definition_" - "time,use_times\n"); + "buffer_id,buffer_name,alt_mem_benefit,size," + "definition_time,use_times,use_names\n"); } const HloBuffer& buffer = alias_analysis_.GetBufferContainingValue(*interval.buffer); const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); int64 definition_time = instruction_schedule.at(interval.buffer->defining_position().instruction); - std::set use_times; + std::vector> uses; for (const HloValue* value : buffer.values()) { for (const HloUse& use : value->uses()) { - use_times.insert(instruction_schedule.at(use.instruction)); + uses.push_back( + {instruction_schedule.at(use.instruction), use.ToString()}); } } + absl::c_sort(uses); + std::vector use_times; + std::vector use_names; + use_times.reserve(uses.size()); + use_names.reserve(uses.size()); + for (auto use : uses) { + use_times.push_back(use.first); + use_names.push_back(use.second); + } absl::StrAppend(debug_str, buffer.id(), ","); absl::StrAppend(debug_str, "\"", interval.buffer->ToShortString(), "\","); @@ -612,7 +624,8 @@ void AlternateMemoryBestFitHeap::AppendBufferInfoDebugString( debug_str, alternate_memory_benefit ? 
*alternate_memory_benefit : 0, ","); absl::StrAppend(debug_str, interval.size, ","); absl::StrAppend(debug_str, definition_time, ","); - absl::StrAppend(debug_str, "\"", absl::StrJoin(use_times, ";"), "\""); + absl::StrAppend(debug_str, "\"", absl::StrJoin(use_times, ";"), "\","); + absl::StrAppend(debug_str, "\"", absl::StrJoin(use_names, ";"), "\""); absl::StrAppend(debug_str, "\n"); } @@ -1820,24 +1833,30 @@ MemorySpaceAssignment::Run(HloModule* module, MemorySpaceAssignment memory_space_assignment(module, options, hlo_live_range); - TF_RETURN_IF_ERROR(memory_space_assignment.FindAllocationSequence( - hlo_live_range, alias_analysis)); - TF_RETURN_IF_ERROR(memory_space_assignment.Process()); - memory_space_assignment.ScheduleAsynchronousCopies(); - TF_RETURN_IF_ERROR(memory_space_assignment.SimplifyGraph()); - TF_RETURN_IF_ERROR(memory_space_assignment.FixSchedule()); - TF_RETURN_IF_ERROR(memory_space_assignment.ExportAndColorBuffers()); + return memory_space_assignment.RunMemorySpaceAssignment(hlo_live_range, + alias_analysis); +} + +StatusOr> +MemorySpaceAssignment::RunMemorySpaceAssignment( + const HloLiveRange& hlo_live_range, + const HloAliasAnalysis& alias_analysis) { + TF_RETURN_IF_ERROR(FindAllocationSequence(hlo_live_range, alias_analysis)); + TF_RETURN_IF_ERROR(Process()); + ScheduleAsynchronousCopies(); + TF_RETURN_IF_ERROR(SimplifyGraph()); + TF_RETURN_IF_ERROR(FixSchedule()); + TF_RETURN_IF_ERROR(ExportAndColorBuffers()); VLOG(3) << "Module after memory space assignment: "; - XLA_VLOG_LINES(3, module->ToString()); - TF_CHECK_OK(module->schedule().Verify()); + XLA_VLOG_LINES(3, module_->ToString()); + TF_CHECK_OK(module_->schedule().Verify()); VLOG(1) << "Maximum number of outstanding async copies: " - << CountMaximumOutstandingAsyncCopies(*module); + << CountMaximumOutstandingAsyncCopies(*module_); - TF_RETURN_IF_ERROR( - memory_space_assignment.VerifyAndExportHeapSimulatorTrace()); + TF_RETURN_IF_ERROR(VerifyAndExportHeapSimulatorTrace()); - return std::move(memory_space_assignment.preset_assignments_); + return std::move(preset_assignments_); } Status MemorySpaceAssignment::FindAllocationSequence( diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h index ab4bc5bf106..eb16db90600 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.h +++ b/tensorflow/compiler/xla/service/memory_space_assignment.h @@ -604,6 +604,8 @@ class MemorySpaceAssignment { AllocationSequence allocation_sequence_; }; + virtual ~MemorySpaceAssignment() = default; + // Runs the MemorySpaceAssignment pass. static StatusOr> Run( HloModule* module, const HloLiveRange& hlo_live_range, @@ -621,13 +623,19 @@ class MemorySpaceAssignment { Status VerifyAndExportHeapSimulatorTrace(); protected: + // Main driver of the memory space assignment pass. + virtual StatusOr> RunMemorySpaceAssignment( + const HloLiveRange& hlo_live_range, + const HloAliasAnalysis& alias_analysis); + // Finds an AllocationSequence for placing buffers in alternate memory using // the AlternateMemoryBestFitHeap algorithm. Must be set before Process() is // called. 
- Status FindAllocationSequence(const HloLiveRange& hlo_live_range, - const HloAliasAnalysis& alias_analysis); + virtual Status FindAllocationSequence(const HloLiveRange& hlo_live_range, + const HloAliasAnalysis& alias_analysis); + + Options options() const { return options_; } - private: MemorySpaceAssignment(HloModule* module, Options options, const HloLiveRange& hlo_live_range) : module_(module), @@ -646,6 +654,9 @@ class MemorySpaceAssignment { } } + AllocationSequence allocations_; + + private: // Process calls Process methods of the allocations after the allocations have // been finalized. Status Process(); @@ -682,7 +693,6 @@ class MemorySpaceAssignment { Options options_; std::vector flattened_instructions_; absl::flat_hash_set computations_in_schedule_; - AllocationSequence allocations_; std::unique_ptr preset_assignments_; std::vector> alternate_memory_assignments_; int64 alternate_memory_size_ = 0; diff --git a/tensorflow/compiler/xla/service/memory_space_propagation.cc b/tensorflow/compiler/xla/service/memory_space_propagation.cc new file mode 100644 index 00000000000..80eb4017477 --- /dev/null +++ b/tensorflow/compiler/xla/service/memory_space_propagation.cc @@ -0,0 +1,67 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/memory_space_propagation.h" + +namespace xla { + +StatusOr MemorySpacePropagation::Run(HloModule* module) { + bool modified = false; + TF_ASSIGN_OR_RETURN(auto dataflow_analysis, + HloDataflowAnalysis::Run(*module)); + dataflow_analysis_ = std::move(dataflow_analysis); + + for (HloComputation* computation : module->MakeNonfusionComputations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kFusion) { + // Propagate the operand subshapes. + for (int operand_idx = 0; operand_idx < instruction->operand_count(); + ++operand_idx) { + modified |= + PropagateSubshapes(instruction->operand(operand_idx)->shape(), + instruction->fused_parameter(operand_idx)); + } + + // Propagate output subshapes. 
+ modified |= PropagateSubshapes(instruction->shape(), + instruction->fused_expression_root()); + } + } + } + return modified; +} + +bool MemorySpacePropagation::PropagateSubshapes( + const Shape& caller_shape, const HloInstruction* callee_instruction) const { + bool modified = false; + for (const ShapeUtil::IndexedShape& indexed_shape : + ShapeUtil::GetLeafShapes(caller_shape)) { + int64 memory_space = indexed_shape.shape.layout().memory_space(); + const HloValue& value = dataflow_analysis_->GetUniqueValueAt( + callee_instruction, indexed_shape.index); + + for (const HloPosition& position : value.positions()) { + Shape* shape = ShapeUtil::GetMutableSubshape( + position.instruction->mutable_shape(), position.index); + if (shape->layout().memory_space() != memory_space) { + shape->mutable_layout()->set_memory_space(memory_space); + modified = true; + } + } + } + return modified; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/memory_space_propagation.h b/tensorflow/compiler/xla/service/memory_space_propagation.h new file mode 100644 index 00000000000..65a1dfd14a6 --- /dev/null +++ b/tensorflow/compiler/xla/service/memory_space_propagation.h @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_PROPAGATION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_PROPAGATION_H_ + +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// This is a legalization pass that propagates the memory space in the layout to +// the fusion computations. +class MemorySpacePropagation : public HloModulePass { + public: + ~MemorySpacePropagation() override = default; + absl::string_view name() const override { return "memory-space-propagation"; } + StatusOr Run(HloModule* module) override; + + private: + // Given the caller shape (operand or output) and its corresponding + // insturction in the fused computation (parameter or root), propagates the + // memory space to all the subshapes in the callee side. Returns true if the + // module is modified. + bool PropagateSubshapes(const Shape& caller_shape, + const HloInstruction* callee_instruction) const; + + std::unique_ptr dataflow_analysis_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_PROPAGATION_H_ diff --git a/tensorflow/compiler/xla/service/memory_space_propagation_test.cc b/tensorflow/compiler/xla/service/memory_space_propagation_test.cc new file mode 100644 index 00000000000..8d74958f6aa --- /dev/null +++ b/tensorflow/compiler/xla/service/memory_space_propagation_test.cc @@ -0,0 +1,203 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/memory_space_propagation.h" + +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class MemorySpacePropagationTest : public HloTestBase { + public: + MemorySpacePropagationTest() + : HloTestBase(), + verifier_(/*layout_sensitive=*/false, /*allow_mixed_precision*/ false) { + } + + Status Verify(HloModule* module) { return verifier_.Run(module).status(); } + + private: + HloVerifier verifier_; +}; + +TEST_F(MemorySpacePropagationTest, NoMemorySpace) { + absl::string_view hlo_string = R"( + HloModule NoMemorySpace + + %fused_computation { + %param_1.3 = s32[1]{0:T(128)} parameter(1) + %constant.2 = s32[]{:T(128)} constant(-2147483648) + %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5 + %param_2.3 = s32[5]{0:T(128)} parameter(2) + %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0 + %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3) + %param_0.1 = s32[6]{0:T(128)} parameter(0) + ROOT %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1) + } + + ENTRY %entry { + %param0 = s32[6]{0:T(128)} parameter(0) + %param1 = s32[1]{0:T(128)} parameter(1) + %param2 = s32[5]{0:T(128)} parameter(2) + %arg0 = s32[6]{0:T(128)} copy(%param0) + %arg1 = s32[1]{0:T(128)} copy(%param1) + %arg2 = s32[5]{0:T(128)} copy(%param2) + %fusion = s32[6]{0:T(128)} fusion(s32[6]{0:T(128)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)} %arg2), kind=kLoop, calls=%fused_computation + ROOT %root = s32[6]{0:T(128)} copy(%fusion) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + MemorySpacePropagation memory_space_propagation; + EXPECT_FALSE(memory_space_propagation.Run(module.get()).ValueOrDie()); + TF_ASSERT_OK_AND_ASSIGN(auto ref, ParseAndReturnVerifiedModule(hlo_string)); + EXPECT_EQ(module->Hash(), ref->Hash()); +} + +TEST_F(MemorySpacePropagationTest, NonTupleOutput) { + absl::string_view hlo_string = R"( + HloModule NonTupleOutput + + %fused_computation { + %param_1.3 = s32[1]{0:T(128)} parameter(1) + %constant.2 = s32[]{:T(128)} constant(-2147483648) + %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5 + %param_2.3 = s32[5]{0:T(128)} parameter(2) + %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0 + %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3) + %param_0.1 = s32[6]{0:T(128)} parameter(0) + ROOT %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1) + } + + ENTRY %entry { + %param0 = s32[6]{0:T(128)} parameter(0) + %param1 = 
s32[1]{0:T(128)} parameter(1) + %param2 = s32[5]{0:T(128)} parameter(2) + %arg0 = s32[6]{0:T(128)S(1)} copy(%param0) + %arg1 = s32[1]{0:T(128)} copy(%param1) + %arg2 = s32[5]{0:T(128)S(1)} copy(%param2) + %fusion = s32[6]{0:T(128)S(1)} fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation + ROOT %root = s32[6]{0:T(128)} copy(%fusion) + } + )"; + absl::string_view expected_hlo_string = R"( + HloModule NonTupleOutput + + %fused_computation { + %param_1.3 = s32[1]{0:T(128)} parameter(1) + %constant.2 = s32[]{:T(128)} constant(-2147483648) + %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5 + %param_2.3 = s32[5]{0:T(128)S(1)} parameter(2) + %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0 + %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3) + %param_0.1 = s32[6]{0:T(128)S(1)} parameter(0) + ROOT %add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1) + } + + ENTRY %entry { + %param0 = s32[6]{0:T(128)} parameter(0) + %param1 = s32[1]{0:T(128)} parameter(1) + %param2 = s32[5]{0:T(128)} parameter(2) + %arg0 = s32[6]{0:T(128)S(1)} copy(%param0) + %arg1 = s32[1]{0:T(128)} copy(%param1) + %arg2 = s32[5]{0:T(128)S(1)} copy(%param2) + %fusion = s32[6]{0:T(128)S(1)} fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation + ROOT %root = s32[6]{0:T(128)} copy(%fusion) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + MemorySpacePropagation memory_space_propagation; + EXPECT_TRUE(memory_space_propagation.Run(module.get()).ValueOrDie()); + TF_EXPECT_OK(Verify(module.get())); + TF_ASSERT_OK_AND_ASSIGN(auto ref, + ParseAndReturnVerifiedModule(expected_hlo_string)); + EXPECT_EQ(module->Hash(), ref->Hash()); +} + +TEST_F(MemorySpacePropagationTest, TupleOutput) { + absl::string_view hlo_string = R"( + HloModule TupleOutput + + %fused_computation { + %param_1.3 = s32[1]{0:T(128)} parameter(1) + %constant.2 = s32[]{:T(128)} constant(-2147483648) + %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5 + %param_2.3 = s32[5]{0:T(128)} parameter(2) + %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0 + %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3) + %param_0.1 = s32[6]{0:T(128)} parameter(0) + %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1) + %multiply.0 = s32[6]{0:T(128)} multiply(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1) + ROOT %tuple = (s32[6]{0:T(128)}, s32[6]{0:T(128)}) tuple(%add.0, %multiply.0) + } + + ENTRY %entry { + %param0 = s32[6]{0:T(128)} parameter(0) + %param1 = s32[1]{0:T(128)} parameter(1) + %param2 = s32[5]{0:T(128)} parameter(2) + %arg0 = s32[6]{0:T(128)S(1)} copy(%param0) + %arg1 = s32[1]{0:T(128)} copy(%param1) + %arg2 = s32[5]{0:T(128)S(1)} copy(%param2) + %fusion = (s32[6]{0:T(128)S(1)}, s32[6]{0:T(128)}) fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation + %gte0 = s32[6]{0:T(128)S(1)} get-tuple-element(%fusion), index=0 + %gte1 = s32[6]{0:T(128)} get-tuple-element(%fusion), index=1 + ROOT %root = s32[6]{0:T(128)} add(%gte0, %gte1) + } + )"; + absl::string_view 
expected_hlo_string = R"( + HloModule TupleOutput + + %fused_computation { + %param_1.3 = s32[1]{0:T(128)} parameter(1) + %constant.2 = s32[]{:T(128)} constant(-2147483648) + %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5 + %param_2.3 = s32[5]{0:T(128)S(1)} parameter(2) + %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0 + %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3) + %param_0.1 = s32[6]{0:T(128)S(1)} parameter(0) + %add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1) + %multiply.0 = s32[6]{0:T(128)} multiply(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1) + ROOT %tuple = (s32[6]{0:T(128)S(1)}, s32[6]{0:T(128)}) tuple(%add.0, %multiply.0) + } + + ENTRY %entry { + %param0 = s32[6]{0:T(128)} parameter(0) + %param1 = s32[1]{0:T(128)} parameter(1) + %param2 = s32[5]{0:T(128)} parameter(2) + %arg0 = s32[6]{0:T(128)S(1)} copy(%param0) + %arg1 = s32[1]{0:T(128)} copy(%param1) + %arg2 = s32[5]{0:T(128)S(1)} copy(%param2) + %fusion = (s32[6]{0:T(128)S(1)}, s32[6]{0:T(128)}) fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation + %gte0 = s32[6]{0:T(128)S(1)} get-tuple-element(%fusion), index=0 + %gte1 = s32[6]{0:T(128)} get-tuple-element(%fusion), index=1 + ROOT %root = s32[6]{0:T(128)} add(%gte0, %gte1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + MemorySpacePropagation memory_space_propagation; + EXPECT_TRUE(memory_space_propagation.Run(module.get()).ValueOrDie()); + TF_EXPECT_OK(Verify(module.get())); + TF_ASSERT_OK_AND_ASSIGN(auto ref, + ParseAndReturnVerifiedModule(expected_hlo_string)); + EXPECT_EQ(module->Hash(), ref->Hash()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/mlir_gpu/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/BUILD index cd679f7412e..a57e4300d6e 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/BUILD @@ -185,11 +185,11 @@ cc_library( "@llvm-project//mlir:LinalgOps", "@llvm-project//mlir:LinalgToLLVM", "@llvm-project//mlir:LinalgTransforms", - "@llvm-project//mlir:LoopOps", - "@llvm-project//mlir:LoopOpsTransforms", "@llvm-project//mlir:LoopsToGPUPass", "@llvm-project//mlir:NVVMDialect", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFTransforms", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc index 33d3690d4ab..847ad918308 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc @@ -31,9 +31,9 @@ limitations under the License. 
#include "mlir/Dialect/LLVMIR/NVVMDialect.h" // from @llvm-project #include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project #include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project -#include "mlir/Dialect/LoopOps/LoopOps.h" // from @llvm-project -#include "mlir/Dialect/LoopOps/Passes.h" // from @llvm-project -#include "mlir/Dialect/LoopOps/Transforms.h" // from @llvm-project +#include "mlir/Dialect/SCF/Passes.h" // from @llvm-project +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project +#include "mlir/Dialect/SCF/Transforms.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project @@ -45,6 +45,7 @@ limitations under the License. #include "mlir/IR/Region.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Transforms/BufferPlacement.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "mlir/Transforms/LoopUtils.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project @@ -60,34 +61,6 @@ namespace { using ::mlir::xla_lhlo::FusionOp; -// Following are some small transformations that are required to clean up code -// after lowering from linalg to loops. - -// A simple pass that applies lowering of HLO to LHLO only within LHLO ops that -// contain regions with HLO ops, e.g. FusionOp, ReduceOp, SelectAndScatterOp. -// This is needed, as these ops are not closed from above and hence nested pass -// managers can not be applied. -struct NestedHloRegionsConverter - : public mlir::PassWrapper { - void runOnFunction() override { - auto& ctx = getContext(); - mlir::OwningRewritePatternList patterns; - mlir::ConversionTarget target(ctx); - target.addLegalDialect<::mlir::xla_lhlo::XlaLhloDialect>(); - ::mlir::xla_hlo::populateHLOToLHLOConversionPattern(&ctx, &patterns); - - getFunction().walk([&](mlir::Operation* op) { - if (op->getNumRegions() == 0) { - return; - } - if (failed(applyPartialConversion(op, target, patterns, nullptr))) { - signalPassFailure(); - } - }); - } -}; - // Replaces a FusionOp by the operations contained in its region. struct FusionOpRemover : public mlir::PassWrapper { @@ -132,7 +105,7 @@ struct StoreForwardingPass // No store operation found. Continue search outside of the parallel // loop if block is in a parallel loop. 
if (auto parallelOp = - llvm::dyn_cast(block->getParentOp())) { + llvm::dyn_cast(block->getParentOp())) { return findStore(parallelOp.getOperation(), matches); } return {}; @@ -388,8 +361,8 @@ struct MapParallelLoops struct FuseInnerParallelLoops : public mlir::PassWrapper { void runOnFunction() override { - getFunction().walk([](mlir::loop::ParallelOp op) { - mlir::loop::naivelyFuseParallelOps(op.region()); + getFunction().walk([](mlir::scf::ParallelOp op) { + mlir::scf::naivelyFuseParallelOps(op.region()); }); } }; @@ -401,7 +374,7 @@ struct ParallelLoopCollapsingToFirstDim void runOnOperation() override { mlir::Operation* module = getOperation(); - module->walk([&](mlir::loop::ParallelOp op) { + module->walk([&](mlir::scf::ParallelOp op) { unsigned num_loops = op.getNumLoops(); std::vector combinedLoops; combinedLoops.reserve(num_loops); @@ -436,8 +409,10 @@ Status LowerLHLOToGPU(mlir::ModuleOp module, tiling_for_unrolling.append(tile_sizes.begin(), tile_sizes.end()); } - // First, lower bodies of LHLO operations that contain HLO ops. - pm.addPass(absl::make_unique()); + // Legalize from HLO to LHLO. + pm.addPass(::mlir::xla_hlo::createLegalizeToLhloPass()); + // Moving `AllocOp`s and inserting missing `DeallocOp`s + pm.addPass(::mlir::createBufferPlacementPass()); // Next, we can strip the outer fusion operation. pm.addPass(absl::make_unique()); // Remove unnecessary LHLO copies. diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD index 014b26c5c78..850d5f5a0cf 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD @@ -22,6 +22,7 @@ glob_lit_tests( default_tags = tf_cuda_tests_tags() + [ "no_pip", "config-cuda-only", + "no_rocm", ], driver = "@llvm-project//mlir:run_lit.sh", exclude = [ diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index ab71c30dcae..2ed5e709d81 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -313,6 +313,8 @@ StatusOr> Service::CreateModuleConfig( if (execution_options->num_partitions() > 0) { config->set_num_partitions(execution_options->num_partitions()); } + config->set_use_spmd_partitioning( + execution_options->use_spmd_partitioning()); config->set_seed(execution_options->seed()); config->set_launch_id(execution_options->launch_id()); config->set_debug_options(execution_options->debug_options()); diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index d2cbdddff2e..8d6ef9faba9 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -257,6 +257,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, case HloOpcode::kLog1p: case HloOpcode::kRsqrt: case HloOpcode::kSqrt: + case HloOpcode::kCbrt: case HloOpcode::kTanh: if (!ShapeUtil::ElementIsFloating(shape) && !ShapeUtil::ElementIsComplex(shape)) { @@ -1998,6 +1999,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return a; } +/* static */ StatusOr ShapeInference::InferAllGatherShape( + const Shape& operand_shape, int64 all_gather_dimension, int64 shard_count) { + TF_RET_CHECK(all_gather_dimension > 0); + TF_RET_CHECK(all_gather_dimension < operand_shape.rank()); + TF_RET_CHECK(shard_count > 0); + auto shape = operand_shape; + shape.set_dimensions(all_gather_dimension, + shard_count * 
shape.dimensions(all_gather_dimension)); + return shape; +} + /* static */ StatusOr ShapeInference::InferAllReduceShape( absl::Span operand_shapes) { for (const Shape* operand_shape : operand_shapes) { diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 2e96a77aa22..2cb5930d098 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -123,6 +123,12 @@ class ShapeInference { // Infers the shape produced by the given triangular solve operation. static StatusOr InferCholeskyShape(const Shape& a); + // Infers the shape produced by an all-gather with the given operand shape, + // concat dimension, and shard count. + static StatusOr InferAllGatherShape(const Shape& operand_shape, + int64 all_gather_dimension, + int64 shard_count); + // Infers the shape produced by a cross replica sum with the given operand // shapes. static StatusOr InferAllReduceShape( diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index a1872330648..b7a67b4e66e 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.h +++ b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -22,6 +22,7 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -93,6 +94,18 @@ class ShapedBuffer { buffers_.replace_shape_ptr(&on_device_shape_); } + // Reset the shape of this shaped buffer and underlying buffer structure. + // + // Precondition: EqualStructure(this->on_device_shape_, on_device_shape). + void set_shapes(const Shape& on_host_shape, const Shape& on_device_shape) { + CHECK(ShapeUtil::EqualStructure(on_device_shape, on_device_shape_)) + << "Structures are not the same. new: " << on_device_shape + << ", old: " << on_device_shape_; + on_host_shape_ = on_host_shape; + on_device_shape_ = on_device_shape; + buffers_.replace_shape_ptr(&on_device_shape_); + } + // Returns the underlying ShapeTree containing all the device addresses in the // ShapedBuffer. const ShapeTree& buffers() const { return buffers_; } diff --git a/tensorflow/compiler/xla/service/spmd/BUILD b/tensorflow/compiler/xla/service/spmd/BUILD new file mode 100644 index 00000000000..5be6a04f934 --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/BUILD @@ -0,0 +1,69 @@ +# Description: SPMD partitioning pass. 
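+# The package exposes the spmd_partitioner library (the partitioning pass and
+# its helper utilities) together with a unit test target.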
+ +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +package( + default_visibility = [":friends"], + licenses = ["notice"], # Apache 2.0 +) + +package_group( + name = "friends", + includes = [ + "//tensorflow/compiler/xla:friends", + ], +) + +cc_library( + name = "spmd_partitioner", + srcs = [ + "spmd_partitioner.cc", + "spmd_partitioner_util.cc", + ], + hdrs = [ + "spmd_partitioner.h", + "spmd_partitioner_util.h", + ], + deps = [ + "//tensorflow/compiler/xla:comparison_util", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:protobuf_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/client/lib:comparators", + "//tensorflow/compiler/xla/service:flatten_call_graph", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", + "//tensorflow/compiler/xla/service:hlo_cse", + "//tensorflow/compiler/xla/service:hlo_dce", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service:hlo_pass_pipeline", + "//tensorflow/compiler/xla/service:hlo_query", + "//tensorflow/compiler/xla/service:hlo_sharding_util", + "//tensorflow/compiler/xla/service:shape_inference", + "//tensorflow/compiler/xla/service:tuple_simplifier", + "//tensorflow/core/platform:numbers", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) + +tf_cc_test( + name = "spmd_partitioner_test", + srcs = ["spmd_partitioner_test.cc"], + deps = [ + ":spmd_partitioner", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service:hlo_pass_pipeline", + "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc new file mode 100644 index 00000000000..b857c8bdbe6 --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc @@ -0,0 +1,4655 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h" + +#include + +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/client/lib/comparators.h" +#include "tensorflow/compiler/xla/comparison_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/protobuf_util.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/flatten_call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_cse.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" +#include "tensorflow/compiler/xla/service/hlo_query.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_util.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h" +#include "tensorflow/compiler/xla/service/tuple_simplifier.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/numbers.h" + +namespace xla { +namespace spmd { + +string SpmdLogger::MakeReport() { + string report; + absl::StrAppend(&report, + "\n\n***** SPMD memory during transformation *****\n"); + + std::sort(entries_.begin(), entries_.end(), + [](auto const& entry0, auto const& entry1) { + return entry0.first > entry1.first; + }); + for (int64 i = 0; + i < std::min(report_instruction_count_, entries_.size()); ++i) { + absl::StrAppend( + &report, "\n ", + tensorflow::strings::HumanReadableNumBytes(entries_[i].first), " : ", + entries_[i].second, "\n"); + } + + return report; +} + +void SpmdLogger::RegisterLogEntry(HloInstruction* hlo, + const std::vector& group) { + string report = hlo->ToString(); + int64 max_value = -1; + for (HloInstruction* inst : group) { + if (inst->shape().IsTuple()) { + continue; + } + max_value = + std::max(max_value, ShapeUtil::ByteSizeOf(inst->shape(), 4)); + absl::StrAppend(&report, " * ", inst->ToString(), "\n"); + } + entries_.push_back(std::make_pair(max_value, report)); +} + +/* static */ string SpmdLogger::ReportBeforePartition( + const HloModule& module, int64 report_instruction_count) { + string report; + absl::StrAppend(&report, + "\n\n***** SPMD memory usage before partition *****\n"); + absl::StrAppend(&report, "\n ** Replicated instructions\n"); + absl::StrAppend(&report, ReportMemoryUsage( + module, + [](const HloInstruction* hlo) { + return !hlo->has_sharding() || + hlo->sharding().IsReplicated(); + }, + report_instruction_count)); + absl::StrAppend(&report, "\n ** All instructions\n"); + absl::StrAppend(&report, + ReportMemoryUsage( + module, [](const HloInstruction* hlo) { return true; }, + report_instruction_count)); + return report; +} + +/* static */ string SpmdLogger::ReportAfterPartition( + 
const HloModule& module, int64 report_instruction_count) { + string report; + absl::StrAppend(&report, + "\n\n***** SPMD memory usage after partition *****\n"); + absl::StrAppend(&report, + ReportMemoryUsage( + module, [](const HloInstruction* hlo) { return true; }, + report_instruction_count)); + return report; +} + +template +/* static */ string SpmdLogger::ReportMemoryUsage( + const HloModule& module, const F& filter, int64 report_instruction_count) { + string report; + std::vector instructions; + instructions.reserve(module.instruction_count()); + + for (auto computation : module.computations()) { + if (computation->IsFusionComputation()) { + continue; + } + for (auto hlo : computation->instructions()) { + if (hlo->shape().IsTuple() || + ShapeUtil::IsEffectiveScalar(hlo->shape())) { + continue; + } + if (filter(hlo)) { + instructions.push_back(hlo); + } + } + } + + const auto add_report = [&](std::vector* insts) { + std::sort(insts->begin(), insts->end(), + [](const HloInstruction* inst0, const HloInstruction* inst1) { + return ShapeUtil::ByteSizeOf(inst0->shape()) > + ShapeUtil::ByteSizeOf(inst1->shape()); + }); + for (int64 i = 0; + i < std::min(report_instruction_count, insts->size()); ++i) { + absl::StrAppend(&report, " ", + tensorflow::strings::HumanReadableNumBytes( + ShapeUtil::ByteSizeOf((*insts)[i]->shape())), + " : ", (*insts)[i]->ToString(), "\n"); + } + }; + + add_report(&instructions); + return report; +} + +namespace { + +// Returns the replica group configuration where each replica belongs to its own +// group. +std::vector CreateReplicaGroups(int64 num_replicas) { + std::vector groups(num_replicas); + for (int64 i = 0; i < num_replicas; ++i) { + groups[i].add_replica_ids(i); + } + return groups; +} + +bool CanReshardWithAllToAll(const HloSharding& source, + const HloSharding& target) { + return UniqueTiledDim(source) && UniqueTiledDim(target) && + UniqueTiledDim(source) != UniqueTiledDim(target); +} + +bool CanReshardWithCollectivePermute(const HloSharding& source, + const HloSharding& target) { + return UniqueTiledDim(source) && UniqueTiledDim(target) && + UniqueTiledDim(source) == UniqueTiledDim(target) && source != target; +} + +// Clears all sharding attributes from instructions in the module. This must be +// called only after all SPMD transformation is complete. +Status ClearShardingAttributes(HloModule* module) { + for (HloComputation* computation : module->computations()) { + for (HloInstruction* hlo : computation->instructions()) { + // Keep sharding annotation on Infeed and entry parameters since they're + // used by HloReplicationAnalysis later (for ArCrsCombiner). 
+ if (hlo->opcode() == HloOpcode::kInfeed) { + continue; + } + if (hlo->opcode() == HloOpcode::kParameter && + computation == module->entry_computation()) { + continue; + } + hlo->clear_sharding(); + } + } + return Status::OK(); +} + +} // namespace + +HloInstruction* SpmdBuilder::AddInstruction( + std::unique_ptr instruction) { + HloInstruction* hlo = + HloComputation::Builder::AddInstruction(std::move(instruction)); + if (visiting_hlo_) { + instructions_[visiting_hlo_].push_back(hlo); + } + return hlo; +} + +PartitionedHlo PartitionedHlo::Reshard(const HloSharding& target) { + auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache; + for (auto& entry : cache) { + if (entry.first == target) { + return entry.second; + } + } + cache.emplace_back(target, ReshardNoCache(target)); + state_.reshard_cache->per_hlo_cache[cache.back().second.hlo()] + .reshard_cache.emplace_back(sharding(), *this); + return cache.back().second; +} + +PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) { + VLOG(2) << "Resharding " << hlo_->ToString() << " from " + << hlo_->sharding().ToString() << " to " << target.ToString(); + const Shape& shape = hlo_->shape(); + CHECK(shape.IsTuple() || !target.IsTuple()); + + // Tuple shape instructions may have non-tuple sharding, which means that the + // same sharding applies to all the leaves. + if (shape.IsTuple() && !target.IsTuple()) { + return Reshard(target.GetTupleSharding(shape).ValueOrDie()); + } + + // For a tuple shape, recursively apply Reshard to all the leaves and return + // a tuple instruction. + if (shape.IsTuple()) { + std::vector elements; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + auto subshape = ShapeUtil::GetTupleElementShape(shape, i); + auto element = state_.b->AddInstruction( + HloInstruction::CreateGetTupleElement(subshape, hlo(), i)); + element->set_sharding(sharding().GetSubSharding(shape, {i})); + elements.push_back( + PartitionedHlo( + element, ShapeUtil::GetTupleElementShape(base_shape_, i), state_) + .Reshard(target.GetSubSharding(shape, {i})) + .hlo()); + } + auto tuple = + state_.b->AddInstruction(HloInstruction::CreateTuple(elements)); + tuple->set_sharding(target); + return PartitionedHlo(tuple, base_shape_, state_); + } + + if (sharding() == target) { + return *this; + } + + if (shape.element_type() == TOKEN) { + return *this; + } + + if (CanReshardWithCollectivePermute(sharding(), target)) { + return ReshardWithCollectivePermute(target); + } + + if (CanReshardWithAllToAll(sharding(), target)) { + return ReshardWithAllToAll(target); + } + + // If not replicated yet, first replicate and then reshard to use one of the + // two implementations below. + if (!sharding().IsReplicated()) { + return Replicate().Reshard(target); + } + + // 'Replicated' to 'SingleDevice'. + if (target.IsTileMaximal()) { + auto copy = state_.b->AddInstruction( + HloInstruction::CreateUnary(hlo_->shape(), HloOpcode::kCopy, hlo_)); + copy->set_sharding(target); + return PartitionedHlo(copy, base_shape_, state_); + } + + // 'Replicated' to 'Tiled'. 
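+ // Each partition dynamic-slices its own shard out of the (possibly padded)
+ // replicated value at offsets derived from its partition id. For example,
+ // an s32[6] array tiled across 2 partitions has a [3] shard shape, and
+ // partition 1 starts its slice at offset 3.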
+ auto padded_hlo = + PadBaseShapeBeforeUnevenTiledSharding(hlo_, target, state_.b); + auto shard_shape = MakePartitionedShape(shape, target); + auto slice = state_.b->AddInstruction(HloInstruction::CreateDynamicSlice( + shard_shape, padded_hlo, + MakePartitionOffsets(shape, target, state_.partition_id, state_.b), + shard_shape.dimensions())); + slice->set_sharding(target); + return PartitionedHlo(slice, base_shape_, state_); +} + +PartitionedHlo PartitionedHlo::PadWithValue(HloInstruction* pad_value) const { + const HloSharding& sharding = hlo_->sharding(); + const Shape& shape = hlo_->shape(); + CHECK(!shape.IsTuple() && shape.element_type() != TOKEN); + if (sharding.IsReplicated() || EvenlyPartitions(base_shape_, sharding)) { + return *this; + } + CHECK(!sharding.IsTileMaximal()); + auto index_shape = ShapeUtil::ChangeElementType(shape, S32); + auto mask_shape = ShapeUtil::ChangeElementType(index_shape, PRED); + auto get_mask_for_dim = [&](int64 dim, HloInstruction* start_index) { + // Comparison: iota + start_index < valid_size + auto iota = + state_.b->AddInstruction(HloInstruction::CreateIota(index_shape, dim)); + auto broadcast_start_index = state_.b->AddInstruction( + HloInstruction::CreateBroadcast(index_shape, start_index, {})); + auto index_in_full_shape = + state_.b->AddInstruction(HloInstruction::CreateBinary( + index_shape, HloOpcode::kAdd, iota, broadcast_start_index)); + auto valid_size = state_.b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(base_shape_.dimensions(dim)))); + auto broadcast_valid_size = state_.b->AddInstruction( + HloInstruction::CreateBroadcast(index_shape, valid_size, {})); + return state_.b->AddInstruction(HloInstruction::CreateCompare( + mask_shape, index_in_full_shape, broadcast_valid_size, + ComparisonDirection::kLt)); + }; + + HloInstruction* mask = nullptr; + auto offsets = MakePartitionOffsets(base_shape_, sharding, + state_.partition_id, state_.b); + for (int64 i = 0; i < shape.rank(); ++i) { + if (base_shape_.dimensions(i) % sharding.tile_assignment().dim(i) == 0) { + continue; + } + if (mask == nullptr) { + mask = get_mask_for_dim(i, offsets[i]); + } else { + mask = state_.b->AddInstruction( + HloInstruction::CreateBinary(mask->shape(), HloOpcode::kAnd, mask, + get_mask_for_dim(i, offsets[i]))); + } + } + + if (mask == nullptr) { + return *this; + } + + auto broadcast_pad_value = state_.b->AddInstruction( + HloInstruction::CreateBroadcast(shape, pad_value, {})); + auto result = state_.b->AddInstruction(HloInstruction::CreateTernary( + shape, HloOpcode::kSelect, mask, hlo_, broadcast_pad_value)); + result->set_sharding(sharding); + return PartitionedHlo(result, base_shape_, state_); +} + +absl::optional +PartitionedHlo::ReshardAsWindowedInput(const Window& window, + const HloSharding& target, + HloInstruction* pad_value, + bool mask_invalid_region) { + auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].window_reshard_cache; + for (auto& entry : cache) { + if (std::get<0>(entry) == target && + protobuf_util::ProtobufEquals(std::get<1>(entry), window)) { + return std::get<2>(entry); + } + } + auto update_cache = [&](WindowedInputShardReturnValue result) { + cache.emplace_back(target, window, std::move(result)); + return std::get<2>(cache.back()); + }; + VLOG(2) << "ReshardAsWindowedInput()\n" + << "\twindow:" << window_util::ToString(window) + << "\ttarget sharding:" << target.ToString(); + + CHECK(!target.IsTileMaximal()); + auto partition_ordinals = + MakeTiledPartitionOrdinals(target, state_.partition_id, 
state_.b); + auto shard_shape = base_shape_; + + std::vector start_on_padded_calculations( + base_shape_.rank()); + std::vector limit_on_padded_calculations( + base_shape_.rank()); + std::vector dynamic_slice_offset_on_output( + base_shape_.rank(), nullptr); + + Window shard_window = window; + auto padded_shape = base_shape_; + std::vector offsets_on_padded_shape(base_shape_.rank()); + std::vector per_shard_window_counts(base_shape_.rank()); + std::vector explicit_left_padding(base_shape_.rank()); + for (int64 i = 0; i < base_shape_.rank(); ++i) { + // Do not pad non-partitioned dimensions. + int64 shard_count = target.tile_assignment().dim(i); + if (shard_count == 1) { + offsets_on_padded_shape[i] = state_.b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(S32))); + continue; + } + const auto& wd = window.dimensions(i); + if (wd.window_dilation() != 1) { + // TODO(yuanzx): Support window dilation. + VLOG(2) << "Failed to reshard window operand due to window dilation"; + return absl::nullopt; + } + int64 full_size = + base_shape_.dimensions(i) + + (wd.base_dilation() - 1) * (base_shape_.dimensions(i) - 1) + + wd.padding_high() + wd.padding_low(); + if (full_size < wd.size()) { + VLOG(2) << "Failed to reshard window operand because the window size is " + "larger than padded base size"; + return absl::nullopt; + } + int64 window_count = (full_size - wd.size()) / wd.stride() + 1; + per_shard_window_counts[i] = CeilOfRatio(window_count, shard_count); + if (wd.stride() != 1 && + (wd.stride() * per_shard_window_counts[i]) % wd.base_dilation() != 0) { + // TODO(yuanzx): Support this case. + VLOG(2) << "Failed to reshard window operand due to non-trivial dilation"; + return absl::nullopt; + } + + // We use explicit padding for full dilations, then use padding_low and + // padding_high on the sharded op for the remaining. padding_low and + // padding_high are now given initial values, which will be later updated if + // dilation is not 1. + auto swd = shard_window.mutable_dimensions(i); + explicit_left_padding[i] = wd.padding_low() / wd.base_dilation(); + swd->set_padding_low(wd.padding_low() % wd.base_dilation()); + swd->set_padding_high(0); + + // Calculation for the first element needed on the 'padded-but-not-dilated' + // shape. The start on the dilated shape could be a hole, so we add + // wd.base_dilation() - 1 to the constant term to skip the leading holes. + start_on_padded_calculations[i] = MultiplyAddDivideOffsetCalculation( + wd.stride() * per_shard_window_counts[i], + wd.base_dilation() - 1 - swd->padding_low(), wd.base_dilation()); + int64 dilated_shard_size = + wd.stride() * (per_shard_window_counts[i] - 1) + wd.size(); + limit_on_padded_calculations[i] = MultiplyAddDivideOffsetCalculation( + wd.stride() * per_shard_window_counts[i], + dilated_shard_size + wd.base_dilation() - 1 - swd->padding_low(), + wd.base_dilation()); + + offsets_on_padded_shape[i] = start_on_padded_calculations[i].Calculate( + partition_ordinals[i], state_.b); + + auto shard_size_function = + limit_on_padded_calculations[i] - start_on_padded_calculations[i]; + int64 max_shard_size = shard_size_function.MaxInRange(0, shard_count); + shard_shape.set_dimensions(i, max_shard_size); + padded_shape.set_dimensions( + i, limit_on_padded_calculations[i].Calculate(shard_count - 1)); + + // For base dilation, calculate the needed padding_low and padding_high, as + // well as the offset for the output if a dynamic slice is needed after the + // sharded op. 
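+ // padding_high is grown until every shard covers the limit required by its
+ // windows on the dilated shard; for stride 1, padding_low is set to the
+ // maximum first-valid-element offset across shards, and a dynamic-slice
+ // offset on the output is emitted when shards disagree on that offset.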
+ if (wd.base_dilation() != 1) { + // Returns the offset of a shard's first valid element in the dilated + // shard. + auto get_first_valid_element_offset_on_dilated_shard = + [&](int64 shard_ordinal) { + return start_on_padded_calculations[i].Calculate(shard_ordinal) * + wd.base_dilation() + + swd->padding_low() - + wd.stride() * per_shard_window_counts[i] * shard_ordinal; + }; + CHECK_EQ(get_first_valid_element_offset_on_dilated_shard(0), + swd->padding_low()); + + // Determine swd->padding_high. + for (int64 shard_ordinal = 0; shard_ordinal < shard_count; + ++shard_ordinal) { + int64 wanted_limit_on_dilated_shard = + wd.stride() * (per_shard_window_counts[i] - 1) + wd.size(); + int64 actual_limit_on_dilated_shard_without_pad_high = + get_first_valid_element_offset_on_dilated_shard(shard_ordinal) + + (max_shard_size - 1) * wd.base_dilation() + 1; + swd->set_padding_high(std::max( + swd->padding_high(), + wanted_limit_on_dilated_shard - + actual_limit_on_dilated_shard_without_pad_high)); + } + + // Determine swd->padding_low and output dynamic slice index. + if (wd.stride() == 1) { + int64 max_pad_low = get_first_valid_element_offset_on_dilated_shard(0); + bool all_same = true; + for (int64 shard_ordinal = 1; shard_ordinal < shard_count; + ++shard_ordinal) { + int64 start = + get_first_valid_element_offset_on_dilated_shard(shard_ordinal); + if (start != swd->padding_low()) { + all_same = false; + } + max_pad_low = std::max(max_pad_low, start); + } + if (!all_same) { + auto start_on_padded_input = + start_on_padded_calculations[i].Calculate(partition_ordinals[i], + state_.b); + // We will calculate + // max_pad_low - (first_window - required_first_window) + // which equals + // required_first_window - (first_window - max_pad_low) + auto first_window_minus_max_pad_low = + MultiplyAddDivideOffsetCalculation( + wd.base_dilation(), swd->padding_low() - max_pad_low, 1) + .Calculate(start_on_padded_input, state_.b); + auto required_first_window = + MultiplyAddDivideOffsetCalculation(per_shard_window_counts[i], 0, + 1) + .Calculate(partition_ordinals[i], state_.b); + dynamic_slice_offset_on_output[i] = + state_.b->AddInstruction(HloInstruction::CreateBinary( + required_first_window->shape(), HloOpcode::kSubtract, + required_first_window, first_window_minus_max_pad_low)); + } + swd->set_padding_low(max_pad_low); + } else { + CHECK_EQ( + (wd.stride() * per_shard_window_counts[i]) % wd.base_dilation(), 0) + << "General base dilation not yet implemented."; + // padding_low on all shards should equal the initially assigned + // swd->padding_low(), i.e., the padding_low() on the original window. + } + } + } + + // Returns the output dynamic slice offset when needed, and absl::nullopt + // otherwise. + auto get_dynamic_slice_offset_on_output_if_needed = + [&]() -> absl::optional> { + if (absl::c_all_of( + dynamic_slice_offset_on_output, + [](HloInstruction* offset) { return offset == nullptr; })) { + return absl::nullopt; + } + auto zero = state_.b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(S32))); + for (int64 i = 0; i < dynamic_slice_offset_on_output.size(); ++i) { + if (dynamic_slice_offset_on_output[i] == nullptr) { + dynamic_slice_offset_on_output[i] = zero; + } + } + return dynamic_slice_offset_on_output; + }; + + // If the currrent HLO is replicated, pad then slice. 
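+ // The pad uses explicit_left_padding and the padded shape computed above as
+ // edge padding; each partition then dynamic-slices its windowed shard at
+ // offsets_on_padded_shape, so no halo exchange is needed on this path.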
+ if (sharding().IsReplicated()) { + PaddingConfig padding_config; + for (int64 i = 0; i < base_shape_.rank(); ++i) { + auto padding_config_dim = padding_config.add_dimensions(); + padding_config_dim->set_interior_padding(0); + // Do not pad non-partitioned dimensions. + if (target.tile_assignment().dim(i) == 1) { + padding_config_dim->set_edge_padding_low(0); + padding_config_dim->set_edge_padding_high(0); + continue; + } + padding_config_dim->set_edge_padding_low(explicit_left_padding[i]); + padding_config_dim->set_edge_padding_high(padded_shape.dimensions(i) - + explicit_left_padding[i] - + base_shape_.dimensions(i)); + } + auto padded_hlo = ShapeUtil::Compatible(padded_shape, base_shape_) + ? hlo_ + : state_.b->AddInstruction(HloInstruction::CreatePad( + padded_shape, hlo_, pad_value, padding_config)); + auto sharded_input = + state_.b->AddInstruction(HloInstruction::CreateDynamicSlice( + shard_shape, padded_hlo, offsets_on_padded_shape, + shard_shape.dimensions())); + return update_cache(WindowedInputShardReturnValue{ + sharded_input, shard_window, + get_dynamic_slice_offset_on_output_if_needed()}); + } + + if (target != sharding()) { + return Replicate().ReshardAsWindowedInput(window, target, pad_value); + } + + // Halo exchange. + HloInstruction* visiting_hlo = hlo_; + auto original_shard_shape = MakePartitionedShape(base_shape_, target); + + std::vector left_halo_size_functions(base_shape_.rank()); + std::vector right_halo_size_functions(base_shape_.rank()); + // TODO(yuanzx): We are concatenating on each sharded dimension one at time, + // and in the second dimension (and beyond) we create halos by slicing the + // concat in the previous dimension, which is not optimal. We should generate + // halos only concating slices, instead of slicing concats. + for (int dim = 0; dim < base_shape_.rank(); ++dim) { + int64 shard_count = target.tile_assignment().dim(dim); + if (shard_count == 1) { + continue; + } + int64 input_shard_size = + CeilOfRatio(base_shape_.dimensions(dim), shard_count); + + // Left halo. The size of the halo is derived by subtracting the first read + // element offset of the i'th partition from the limit of the (i-1)'th + // partition. + MultiplyAddDivideOffsetCalculation shard_limit_of_previous_on_padded( + input_shard_size, explicit_left_padding[dim], 1); + left_halo_size_functions[dim] = + shard_limit_of_previous_on_padded - start_on_padded_calculations[dim]; + + // Right halo. 
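+ // The size of the right halo is derived by subtracting the start offset of
+ // the next partition's data on the padded shape from the limit needed by
+ // this partition's windows.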
+ MultiplyAddDivideOffsetCalculation shard_start_of_next_on_padded( + input_shard_size, input_shard_size + explicit_left_padding[dim], 1); + right_halo_size_functions[dim] = + limit_on_padded_calculations[dim] - shard_start_of_next_on_padded; + + auto resharded = ExchangeHaloAndGetValidData( + visiting_hlo, base_shape_, left_halo_size_functions[dim], + right_halo_size_functions[dim], explicit_left_padding[dim], + padded_shape.dimensions(dim), shard_shape.dimensions(dim), dim, target, + offsets_on_padded_shape[dim], pad_value, partition_ordinals[dim], + state_.collective_ops_creator, state_.next_channel_id, state_.b, + mask_invalid_region); + if (!resharded) { + VLOG(1) << "ReshardAsWindowedInput failed without replicate first: halo " + "is beyond the neighbor."; + return Replicate().ReshardAsWindowedInput(window, target, pad_value); + } + visiting_hlo = *resharded; + } + return update_cache(WindowedInputShardReturnValue{ + visiting_hlo, shard_window, + get_dynamic_slice_offset_on_output_if_needed()}); +} + +PartitionedHlo PartitionedHlo::Replicate() { + const HloSharding& sharding = hlo_->sharding(); + const Shape& shape = hlo_->shape(); + CHECK(!shape.IsTuple() && shape.element_type() != TOKEN); + + if (sharding.IsReplicated()) { + return *this; + } + auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache; + for (auto& entry : cache) { + if (entry.first.IsReplicated()) { + return entry.second; + } + } + auto update_cache = [&](PartitionedHlo resharded) { + state_.reshard_cache->per_hlo_cache[resharded.hlo()] + .reshard_cache.emplace_back(sharding, *this); + cache.emplace_back(HloSharding::Replicate(), std::move(resharded)); + return cache.back().second; + }; + // 'Single Device' to 'Repliated'. + if (sharding.IsTileMaximal()) { + return update_cache(Broadcast()); + } + + // 'Tiled' to 'Replicated'. 
+ Shape padded_base_shape = shape; + for (int64 i = 0; i < padded_base_shape.rank(); ++i) { + padded_base_shape.set_dimensions( + i, shape.dimensions(i) * sharding.tile_assignment().dim(i)); + } + auto zero = state_.b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type()))); + auto zero_bcast = state_.b->AddInstruction( + HloInstruction::CreateBroadcast(padded_base_shape, zero, {})); + auto dus = state_.b->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + padded_base_shape, zero_bcast, hlo_, + MakePartitionOffsets(padded_base_shape, sharding, state_.partition_id, + state_.b))); + HloComputation* reduction = + MakeBinaryAdd(shape.element_type(), state_.module); + + auto all_reduce = + state_.collective_ops_creator.create_cross_partition_all_reduce( + state_.b, dus, reduction, NewChannel()); + HloInstruction* result = all_reduce; + if (!ShapeUtil::Compatible(base_shape_, padded_base_shape)) { + std::vector start_indices(shape.rank(), 0); + std::vector strides(shape.rank(), 1); + result = state_.b->AddInstruction(HloInstruction::CreateSlice( + base_shape_, result, start_indices, base_shape_.dimensions(), strides)); + } + result->set_sharding(HloSharding::Replicate()); + return update_cache(PartitionedHlo(result, base_shape_, state_)); +} + +PartitionedHlo PartitionedHlo::Broadcast() const { + const Shape& shape = hlo_->shape(); + const HloSharding& sharding = hlo_->sharding(); + CHECK(sharding.HasUniqueDevice()); + CHECK(!shape.IsTuple() && shape.element_type() != TOKEN); + + auto src_core_id = state_.b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(sharding.GetUniqueDevice()))); + Shape bcast_shape = ShapeUtil::ChangeElementType(shape, PRED); + auto is_src_core = state_.b->AddInstruction(HloInstruction::CreateBroadcast( + bcast_shape, + state_.b->AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), state_.partition_id, src_core_id, + ComparisonDirection::kEq)), + {})); + + auto zero = state_.b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type()))); + auto zero_bcast = state_.b->AddInstruction( + HloInstruction::CreateBroadcast(shape, zero, {})); + auto operand = state_.b->AddInstruction(HloInstruction::CreateTernary( + shape, HloOpcode::kSelect, is_src_core, hlo(), zero_bcast)); + HloComputation* reduction = + MakeBinaryAdd(shape.element_type(), state_.module); + + auto result = state_.collective_ops_creator.create_cross_partition_all_reduce( + state_.b, operand, reduction, NewChannel()); + result->set_sharding(HloSharding::Replicate()); + return PartitionedHlo(result, base_shape_, state_); +} + +PartitionedHlo PartitionedHlo::ReshardWithAllToAll( + const HloSharding& target) const { + int64 partition_count = sharding().tile_assignment().num_elements(); + absl::optional input_partition_dim = UniqueTiledDim(sharding()); + absl::optional output_partition_dim = UniqueTiledDim(target); + CHECK(input_partition_dim.has_value()); + CHECK(output_partition_dim.has_value()); + + // If the device order is different in the target, fix the order with + // ReshardWithCollectivePermute. 
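+ // This first reshards to a sharding with the source tiling but the target's
+ // device order, then recurses with the devices already in their final
+ // positions.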
+ auto input_tile_fixed_device_order = target.tile_assignment(); + input_tile_fixed_device_order.Reshape( + sharding().tile_assignment().dimensions()); + auto input_sharding_fixed_device_order = + HloSharding::Tile(input_tile_fixed_device_order); + if (input_sharding_fixed_device_order != sharding()) { + auto fixed_order = + ReshardWithCollectivePermute(input_sharding_fixed_device_order); + return fixed_order.ReshardWithAllToAll(target); + } + + auto padded_hlo = + PadBaseShapeBeforeUnevenTiledSharding(hlo_, target, state_.b); + + // The order of ids in the group must follow the target sharding. + std::vector groups(1); + for (int64 device : target.tile_assignment()) { + groups[0].add_replica_ids(device); + } + + HloInstruction* result = nullptr; + + // Split along the split dimension (output_partition_dim) of the all-to-all + // output. + std::vector dimensions; + for (int64 i = 0; i < base_shape_.rank(); ++i) { + if (i == *output_partition_dim) { + dimensions.push_back(partition_count); + dimensions.push_back(padded_hlo->shape().dimensions(i) / partition_count); + } else { + dimensions.push_back(padded_hlo->shape().dimensions(i)); + } + } + auto reshape = state_.b->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(base_shape_.element_type(), dimensions), + padded_hlo)); + // After the reshape, it is guaranteed to have at least 3 dimensions. + auto all_to_all = + state_.collective_ops_creator.create_cross_partition_all_to_all( + state_.b, {reshape}, groups, (*state_.next_channel_id)++, + output_partition_dim); + + // Reorder the split dimension of the reshape to be located in front of the + // input partition dimension, so the two dimensions can be combined. + int64 new_input_partition_dim = (*output_partition_dim < *input_partition_dim) + ? *input_partition_dim + 1 + : *input_partition_dim; + std::vector permutation; + for (int64 i = 0; i < all_to_all->shape().rank(); ++i) { + if (i == *output_partition_dim) { + continue; + } + if (i == new_input_partition_dim) { + permutation.push_back(*output_partition_dim); + } + permutation.push_back(i); + } + auto transpose = state_.b->AddInstruction(HloInstruction::CreateTranspose( + ShapeInference::InferTransposeShape(all_to_all->shape(), permutation) + .ValueOrDie(), + all_to_all, permutation)); + + // Combine the split dimension and the input partition dimension. 
+ auto new_shape = ShapeInference::InferAllToAllShape( + padded_hlo->shape(), *output_partition_dim, + *input_partition_dim, partition_count) + .ValueOrDie(); + result = state_.b->AddInstruction( + HloInstruction::CreateReshape(new_shape, transpose)); + + const Shape result_shape = MakePartitionedShape(base_shape_, target); + if (result_shape != result->shape()) { + result = state_.b->AddInstruction(HloInstruction::CreateSlice( + result_shape, result, std::vector(result_shape.rank(), 0), + result_shape.dimensions(), std::vector(result_shape.rank(), 1))); + } + result->set_sharding(target); + return PartitionedHlo(result, base_shape_, state_); +} + +PartitionedHlo PartitionedHlo::ReshardWithCollectivePermute( + const HloSharding& target) const { + CHECK(CanReshardWithCollectivePermute(sharding(), target)); + std::vector> src_dst_pairs; + sharding().tile_assignment().Each( + [&](absl::Span indices, int64 src_device) { + int64 dst_device = target.tile_assignment()(indices); + if (dst_device != src_device) { + src_dst_pairs.emplace_back(src_device, dst_device); + } + }); + auto cp = + state_.collective_ops_creator.create_cross_partition_collective_permute( + state_.b, hlo(), src_dst_pairs, (*state_.next_channel_id)++); + cp->set_sharding(target); + return PartitionedHlo(cp, base_shape_, state_); +} + +SpmdPartitioningVisitor::SpmdPartitioningVisitor( + HloComputation* computation, int64 num_partitions, int64 num_replicas, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdLogger* logger, SpmdPartitionerOptions options, + SpmdPartitioner* partitioner) + : changed_(false), + module_(computation->parent()), + num_partitions_(num_partitions), + num_replicas_(num_replicas), + collective_ops_creator_(collective_ops_creator), + next_channel_id_(next_channel_id), + b_(SpmdBuilder(computation->name() + "_spmd", /*hlo=*/nullptr)), + partition_id_(collective_ops_creator_.create_partition_id(&b_)), + logger_(logger), + options_(std::move(options)), + partitioner_(partitioner) {} + +Status SpmdPartitioningVisitor::DefaultAction(HloInstruction* hlo) { + if (hlo->HasSideEffect()) { + return Unimplemented("Side-effect ops cannot be replicated: %s", + hlo->ToString()); + } + + if (hlo->IsElementwise() && hlo->operand_count() > 0) { + return HandleElementwise(hlo); + } + + if (!hlo->sharding().IsTileMaximal()) { + VLOG(1) << "Not partitioned in SPMD mode (DefaultAction):" + << hlo->ToString(); + for (int64 i = 0; i < hlo->operand_count(); ++i) { + VLOG(1) << " operand " << i + << " sharding:" << hlo->operand(i)->sharding().ToString(); + } + } + + // If the instruction cannot be partitioned, replicate the instruction unless + // the instruction has side-effect. 
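+ // Reshard every operand to replicated, clone the instruction on the full
+ // shape, and then reshard the replicated result back to the instruction's
+ // assigned sharding.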
+ std::vector new_operands; + for (HloInstruction* operand : hlo->operands()) { + new_operands.push_back( + GetPartitionedHlo(operand).Reshard(HloSharding::Replicate()).hlo()); + } + auto clone = + b_.AddInstruction(hlo->CloneWithNewOperands(hlo->shape(), new_operands)); + clone->set_sharding(HloSharding::Replicate()); + clone->set_metadata(hlo->metadata()); + SetPartitionedHlo(hlo, + PartitionedHlo(clone, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding())); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::Preprocess(HloInstruction* hlo) { + visiting_hlo_ = hlo; + b_.set_visiting_hlo(hlo); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::Postprocess(HloInstruction* hlo) { + logger_->RegisterLogEntry(GetPartitionedHlo(hlo).hlo(), + b_.derived_instructions(hlo)); + visiting_hlo_ = nullptr; + b_.set_visiting_hlo(nullptr); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleElementwise(HloInstruction* hlo) { + std::vector new_operands; + for (HloInstruction* operand : hlo->operands()) { + new_operands.push_back( + GetPartitionedHlo(operand).Reshard(hlo->sharding()).hlo()); + } + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction(hlo->CloneWithNewOperands( + MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands)); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleConcatenate(HloInstruction* hlo) { + const HloSharding& sharding = hlo->sharding(); + if (sharding.IsTileMaximal()) { + return DefaultAction(hlo); + } + + const Shape shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); + const int64 dimension = hlo->concatenate_dimension(); + if (sharding.tile_assignment().dim(dimension) == 1) { + std::vector new_operands; + for (HloInstruction* operand : hlo->operands()) { + new_operands.push_back( + GetPartitionedHlo(operand).Reshard(sharding).hlo()); + } + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction( + hlo->CloneWithNewOperands(shard_shape, new_operands)); + }); + return Status::OK(); + } + + // If the concatenate dimension is along one of the partitioned dimensions, + // allocate the full output shape, each partition updates its owned region, + // all-reduce across partitions, and then slice its output region. + + // We currently don't support subgroup all-reduce along partitions, so more + // than 1 partitioned dimensions is not supported. + if (sharding.tile_assignment().dim(dimension) != num_partitions_) { + return DefaultAction(hlo); + } + + // temp_output_shape is the output shape where the concatenate dimension + // is changed to the full (and padded to shard count) dimension size. + auto temp_output_shape = MakePartitionedShape(hlo->shape(), sharding); + temp_output_shape.set_dimensions( + dimension, temp_output_shape.dimensions(dimension) * + sharding.tile_assignment().dim(dimension)); + auto temp_output = CreateZero(temp_output_shape, &b_); + + // Offset of each operand along the concatenate dimension. 
+  int64 offset = 0;
+  for (HloInstruction* operand : hlo->operands()) {
+    auto spmd_operand = GetPartitionedHlo(operand).Reshard(sharding).hlo();
+    std::vector<HloInstruction*> start_indices(
+        hlo->shape().rank(), b_.AddInstruction(HloInstruction::CreateConstant(
+                                 LiteralUtil::Zero(S32))));
+    start_indices[dimension] =
+        MultiplyAddDivideOffsetCalculation(
+            spmd_operand->shape().dimensions(dimension), offset, 1)
+            .Calculate(MakeTiledPartitionOrdinals(sharding, partition_id_,
+                                                  &b_)[dimension],
+                       &b_);
+    temp_output = b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+        temp_output_shape, temp_output, spmd_operand, start_indices));
+    offset += operand->shape().dimensions(dimension);
+  }
+  auto all_reduce = collective_ops_creator_.create_cross_partition_all_reduce(
+      &b_, temp_output, MakeBinaryAdd(hlo->shape().element_type(), module_),
+      NewChannel());
+  SetPartitionedHlo(hlo, [&] {
+    auto start_indices =
+        MakeTiledPartitionOrdinals(hlo->sharding(), partition_id_, &b_);
+    start_indices[dimension] = MultiplyAddDivideOffsetCalculation(
+                                   shard_shape.dimensions(dimension), 0, 1)
+                                   .Calculate(start_indices[dimension], &b_);
+    return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
+        shard_shape, all_reduce, start_indices, shard_shape.dimensions()));
+  });
+
+  return Status::OK();
+}
+
+// If partitioning in the operand only happens in dimensions in passthrough
+// dimensions (offset dimensions in the gather output (or scatter update) that
+// have the same size as the operand), returns the corresponding output (or
+// update) sharding by passing through the input sharding.
+absl::optional<HloSharding> PassthroughOperandToGatherOutputOrScatterUpdate(
+    const PartitionedHlo& operand, const Shape& update_or_gather_shape,
+    absl::Span<const int64> collapsed_or_inserted_dims,
+    absl::Span<const int64> index_map,
+    absl::Span<const int64> offset_or_window_dims,
+    absl::Span<const int64> slice_size) {
+  if (operand.sharding().IsTileMaximal()) {
+    return operand.sharding();
+  }
+  std::vector<int64> passthrough_tile(update_or_gather_shape.rank(), 1);
+  int64 collapsed = 0;
+  for (int64 i = 0; i < operand.base_shape().rank(); ++i) {
+    int64 dim_partitions = operand.sharding().tile_assignment().dim(i);
+    if (absl::c_linear_search(collapsed_or_inserted_dims, i) ||
+        absl::c_linear_search(index_map, i)) {
+      if (dim_partitions > 1) {
+        return absl::nullopt;
+      }
+      collapsed++;
+      continue;
+    }
+    if (slice_size[i] != operand.base_shape().dimensions(i) &&
+        dim_partitions > 1) {
+      return absl::nullopt;
+    }
+    int64 offset_dim = offset_or_window_dims[i - collapsed];
+    if (i - collapsed > 0 &&
+        offset_dim < offset_or_window_dims[i - collapsed - 1]) {
+      // Output offsets are transposed, we do not support this case.
+      return absl::nullopt;
+    }
+    passthrough_tile[offset_dim] = dim_partitions;
+  }
+  Array<int64> tile_assignment = operand.sharding().tile_assignment();
+  tile_assignment.Reshape(passthrough_tile);
+  return HloSharding::Tile(tile_assignment);
+}
+
+// Returns whether partitioning in the operand only happens in dimensions with
+// gather/scatter slice size 1.
+bool GatherScatterOperandPartitionedOnlyOnTrivialSliceDims( + const PartitionedHlo& operand, absl::Span index_map, + absl::Span slice_size, int64 num_partitions) { + if (operand.sharding().IsTileMaximal()) { + return false; + } + int64 trivial_slice_dims_partitions = 1; + for (int64 dim : index_map) { + if (slice_size[dim] == 1) { + trivial_slice_dims_partitions *= + operand.sharding().tile_assignment().dim(dim); + } + } + return trivial_slice_dims_partitions == num_partitions; +} + +// Returns the min and max for the indices (replicated) in a scatter/gather +// which has the operand partitioned on trivial slice dimensions (slice size 1). +std::pair +IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims( + const PartitionedHlo& operand, const PartitionedHlo& replicated_indices, + HloInstruction* partition_id, absl::Span index_map, + int64 index_vector_dim, SpmdBuilder* b) { + auto operand_offsets = MakePartitionOffsets( + operand.base_shape(), operand.sharding(), partition_id, b); + // Find the per-dimension index bounds. + std::vector min_indices; + std::vector max_indices; + for (int64 i = 0; i < index_map.size(); ++i) { + int64 dim = index_map[i]; + int64 partitions = operand.sharding().tile_assignment().dim(dim); + if (partitions == 1) { + min_indices.push_back(CreateR0WithType( + replicated_indices.base_shape().element_type(), 0, b)); + max_indices.push_back(CreateR0WithType( + replicated_indices.base_shape().element_type(), + operand.base_shape().dimensions(dim), b)); + continue; + } + auto offset = operand_offsets[dim]; + if (offset->shape().element_type() != + replicated_indices.base_shape().element_type()) { + offset = b->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::MakeShape(replicated_indices.base_shape().element_type(), + {}), + offset)); + } + min_indices.push_back(offset); + auto partition_size_minus_1 = + CreateR0WithType(replicated_indices.base_shape().element_type(), + operand.hlo()->shape().dimensions(dim) - 1, b); + max_indices.push_back(b->AddInstruction(HloInstruction::CreateBinary( + offset->shape(), HloOpcode::kAdd, offset, partition_size_minus_1))); + } + // Broadcast the index bounds to the same shape as the indices. + HloInstruction* broadcast_min; + HloInstruction* broadcast_max; + if (index_vector_dim < replicated_indices.base_shape().rank()) { + // The index vector is an R1, we need to reshape individual bounds to + // [1], and concat them if there are more than one. 
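Stepping away from the broadcasting details for a moment, the bounds computed by the helper above reduce to simple shard arithmetic. A tiny sketch, assuming ceil-sized shards as produced by padded even partitioning (simplified types, not the XLA API):

#include <cstdint>
#include <iostream>
#include <utility>

// For an operand dimension of size dim_size split across `partitions` shards
// of ceil(dim_size / partitions) elements, partition p owns the inclusive
// index range [p * shard, p * shard + shard - 1].
std::pair<int64_t, int64_t> TrivialSliceIndexBounds(int64_t dim_size,
                                                    int64_t partitions,
                                                    int64_t p) {
  const int64_t shard = (dim_size + partitions - 1) / partitions;
  return {p * shard, p * shard + shard - 1};
}

int main() {
  for (int64_t p = 0; p < 4; ++p) {
    const auto bounds = TrivialSliceIndexBounds(/*dim_size=*/10,
                                                /*partitions=*/4, p);
    std::cout << "partition " << p << ": [" << bounds.first << ", "
              << bounds.second << "]\n";
  }
  return 0;
}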
+ for (int64 i = 0; i < min_indices.size(); ++i) { + min_indices[i] = b->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(min_indices[i]->shape().element_type(), {1}), + min_indices[i])); + max_indices[i] = b->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(max_indices[i]->shape().element_type(), {1}), + max_indices[i])); + } + int64 slice_dims = max_indices.size(); + if (slice_dims > 1) { + min_indices[0] = b->AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(min_indices[0]->shape().element_type(), + {slice_dims}), + min_indices, 0)); + max_indices[0] = b->AddInstruction(HloInstruction::CreateConcatenate( + min_indices[0]->shape(), max_indices, 0)); + } + broadcast_min = b->AddInstruction(HloInstruction::CreateBroadcast( + replicated_indices.base_shape(), min_indices[0], {index_vector_dim})); + broadcast_max = b->AddInstruction(HloInstruction::CreateBroadcast( + replicated_indices.base_shape(), max_indices[0], {index_vector_dim})); + } else { + CHECK_EQ(max_indices.size(), 1); + broadcast_min = b->AddInstruction(HloInstruction::CreateBroadcast( + replicated_indices.base_shape(), min_indices[0], {})); + broadcast_max = b->AddInstruction(HloInstruction::CreateBroadcast( + replicated_indices.base_shape(), max_indices[0], {})); + } + return {broadcast_min, broadcast_max}; +} + +Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) { + auto scatter = Cast(hlo); + auto dnums = scatter->scatter_dimension_numbers(); + auto operand = GetPartitionedHlo(scatter->operand(0)); + auto indices = GetPartitionedHlo(scatter->operand(1)); + auto updates = GetPartitionedHlo(scatter->operand(2)); + std::vector slice_size(operand.base_shape().rank(), 1); + int64 num_update_window_dims = 0; + for (int64 i = 0; i < operand.base_shape().rank(); ++i) { + if (absl::c_linear_search(dnums.inserted_window_dims(), i)) { + continue; + } + slice_size[i] = updates.base_shape().dimensions( + dnums.update_window_dims(num_update_window_dims++)); + } + std::vector inserted_window_dims(dnums.inserted_window_dims().begin(), + dnums.inserted_window_dims().end()); + std::vector scatter_dims_to_operand_dims( + dnums.scatter_dims_to_operand_dims().begin(), + dnums.scatter_dims_to_operand_dims().end()); + std::vector update_window_dims(dnums.update_window_dims().begin(), + dnums.update_window_dims().end()); + if (!operand.sharding().IsTileMaximal()) { + auto maybe_passthrough = PassthroughOperandToGatherOutputOrScatterUpdate( + operand, updates.base_shape(), inserted_window_dims, + scatter_dims_to_operand_dims, update_window_dims, slice_size); + // Handle pass through cases if we can use compatible sharding for update. 
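The pass-through case referenced in the comment above can be pictured with plain vectors. The sketch below is a simplified model of PassthroughOperandToGatherOutputOrScatterUpdate that drops the transposed-offset-dims check and substitutes std::optional and int64_t for the XLA types; all names are illustrative, not the real API:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Operand partitioning may only appear on dimensions that survive as
// full-size offset (window) dims in the gather output / scatter update;
// collapsed and index-mapped dims must be unpartitioned. Returns the output
// tile counts, or nullopt if pass-through is impossible.
std::optional<std::vector<int64_t>> PassthroughTile(
    const std::vector<int64_t>& operand_tile,    // partitions per operand dim
    const std::vector<int64_t>& operand_dims,    // operand dimension sizes
    const std::vector<int64_t>& slice_sizes,     // gather slice sizes
    const std::vector<int64_t>& collapsed_dims,  // collapsed slice dims
    const std::vector<int64_t>& index_map,       // start_index_map
    const std::vector<int64_t>& offset_dims,     // output offset dims
    int64_t output_rank) {
  std::vector<int64_t> tile(output_rank, 1);
  int64_t collapsed = 0;
  auto contains = [](const std::vector<int64_t>& v, int64_t x) {
    return std::find(v.begin(), v.end(), x) != v.end();
  };
  for (int64_t i = 0; i < static_cast<int64_t>(operand_tile.size()); ++i) {
    if (contains(collapsed_dims, i) || contains(index_map, i)) {
      if (operand_tile[i] > 1) return std::nullopt;
      ++collapsed;
      continue;
    }
    if (slice_sizes[i] != operand_dims[i] && operand_tile[i] > 1)
      return std::nullopt;
    tile[offset_dims[i - collapsed]] = operand_tile[i];
  }
  return tile;
}

int main() {
  // Operand [8, 16] partitioned 2 ways on dim 1; dim 0 is collapsed & indexed.
  auto tile = PassthroughTile({1, 2}, {8, 16}, {1, 16}, {0}, {0}, {1}, 2);
  if (tile) {
    std::cout << "output tile:";
    for (auto t : *tile) std::cout << " " << t;  // expect: 1 2
    std::cout << "\n";
  }
  return 0;
}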
+ if (maybe_passthrough.has_value()) { + indices = indices.Reshard(HloSharding::Replicate()); + updates = updates.Reshard(*maybe_passthrough); + auto pscatter = b_.AddInstruction(HloInstruction::CreateScatter( + operand.hlo()->shape(), operand.hlo(), indices.hlo(), updates.hlo(), + scatter->to_apply(), dnums, scatter->indices_are_sorted(), + scatter->unique_indices())); + pscatter->set_sharding(*maybe_passthrough); + SetPartitionedHlo(hlo, [&]() { + return PartitionedHlo(pscatter, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); + } + if (GatherScatterOperandPartitionedOnlyOnTrivialSliceDims( + operand, scatter_dims_to_operand_dims, slice_size, + num_partitions_) && + ShapeUtil::ByteSizeOf(updates.base_shape()) < + ShapeUtil::ByteSizeOf(scatter->shape())) { + // Operand is sharded on trivial slice dims (update slice size 1). We can + // adjust the indices on each partition by subtracting the offsets. Then + // we execute a scatter on full updated indices, and out-of-bound accesses + // will have no effect on the result as guaranteed by the scatter + // semantics. + indices = indices.Reshard(HloSharding::Replicate()); + updates = updates.Reshard(HloSharding::Replicate()); + HloInstruction* indices_min; + HloInstruction* indices_max_unused; + std::tie(indices_min, indices_max_unused) = + IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims( + operand, indices, partition_id_, scatter_dims_to_operand_dims, + dnums.index_vector_dim(), &b_); + auto adjusted_indices = b_.AddInstruction(HloInstruction::CreateBinary( + indices.hlo()->shape(), HloOpcode::kSubtract, indices.hlo(), + indices_min)); + auto pscatter = b_.AddInstruction(HloInstruction::CreateScatter( + operand.hlo()->shape(), operand.hlo(), adjusted_indices, + updates.hlo(), scatter->to_apply(), dnums, + scatter->indices_are_sorted(), scatter->unique_indices())); + pscatter->set_sharding(operand.sharding()); + SetPartitionedHlo(hlo, [&]() { + return PartitionedHlo(pscatter, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); + } + } + return DefaultAction(hlo); +} + +Status SpmdPartitioningVisitor::HandleSlice(HloInstruction* hlo) { + const HloSharding& sharding = hlo->sharding(); + if (sharding.IsTileMaximal()) { + return DefaultAction(hlo); + } + + auto operand = GetPartitionedHlo(hlo->operand(0)).Reshard(sharding); + + // Create a window config to represent the slice. 
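The window built in the next lines encodes a static slice as a size-1 window: negative low padding drops the leading elements and the high padding trims everything past the limit. A small sketch of that per-dimension encoding (hypothetical struct, illustrative only):

#include <cstdint>
#include <iostream>

// For one dimension of a slice [start, limit) with the given stride over an
// operand of size operand_dim, the equivalent size-1 window uses:
//   padding_low  = -start               (drop the first `start` elements)
//   padding_high = limit - operand_dim  (trim everything past `limit`)
// and keeps the stride unchanged.
struct WindowDim {
  int64_t padding_low;
  int64_t padding_high;
  int64_t stride;
};

WindowDim SliceAsWindow(int64_t start, int64_t limit, int64_t stride,
                        int64_t operand_dim) {
  return WindowDim{-start, limit - operand_dim, stride};
}

int main() {
  // Slice [2, 9) with stride 1 of a dimension of size 12.
  const WindowDim w = SliceAsWindow(2, 9, 1, 12);
  std::cout << "padding_low=" << w.padding_low
            << " padding_high=" << w.padding_high << " stride=" << w.stride
            << "\n";  // padding_low=-2 padding_high=-3 stride=1
  return 0;
}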
+ Window window; + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_size(1); + dim->set_stride(hlo->slice_strides(i)); + dim->set_window_dilation(1); + dim->set_window_reversal(false); + dim->set_padding_low(-hlo->slice_starts(i)); + dim->set_padding_high(hlo->slice_limits(i) - + hlo->operand(0)->shape().dimensions(i)); + dim->set_base_dilation(1); + } + + auto reshard_operand = operand.ReshardAsWindowedInput( + window, sharding, + CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_), + /*mask_invalid_region=*/false); + if (!reshard_operand.has_value()) { + return DefaultAction(hlo); + } + TF_RET_CHECK(!reshard_operand->dynamic_slice_index_on_output.has_value()); + const Shape& operand_shape = reshard_operand->sharded_input->shape(); + + std::vector start_indices = hlo->slice_starts(); + std::vector limit_indices = hlo->slice_limits(); + std::vector strides = hlo->slice_strides(); + bool need_slice = false; + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + auto dim = reshard_operand->shard_window.dimensions(i); + start_indices[i] = -dim.padding_low(); + limit_indices[i] = operand_shape.dimensions(i) + dim.padding_high(); + if (start_indices[i] != 0 || strides[i] != 1 || + limit_indices[i] != operand_shape.dimensions(i)) { + need_slice = true; + } + } + + SetPartitionedHlo(hlo, [&] { + if (need_slice) { + auto shard_shape = MakePartitionedShape(hlo->shape(), sharding); + return b_.AddInstruction(HloInstruction::CreateSlice( + shard_shape, reshard_operand->sharded_input, start_indices, + limit_indices, strides)); + } + return reshard_operand->sharded_input; + }); + + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleSort(HloInstruction* hlo) { + HloSharding sharding = hlo->sharding(); + if (hlo->shape().IsTuple()) { + // Check that all elements are sharded in the same way. + if (hlo->shape().tuple_shapes_size() == 0) { + return DefaultAction(hlo); + } + sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0}); + for (int64 i = 1; i < hlo->operand_count(); ++i) { + if (sharding != hlo->sharding().GetSubSharding(hlo->shape(), {i})) { + return DefaultAction(hlo); + } + } + } + if (sharding.IsTileMaximal()) { + return DefaultAction(hlo); + } + for (int64 dim : hlo->dimensions()) { + if (sharding.tile_assignment().dim(dim) > 1) { + return DefaultAction(hlo); + } + } + // Reshard operands to the same as the output. + std::vector new_operands; + for (HloInstruction* operand : hlo->operands()) { + new_operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo()); + } + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction(hlo->CloneWithNewOperands( + MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands)); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) { + if (hlo->custom_call_target() == "SPMDFullToShardShape") { + // This op switches from auto partitioning to manual partitioning. 
+ auto input_partitioned = GetPartitionedHlo(hlo->operand(0)); + if (!EvenlyPartitions(hlo->shape(), input_partitioned.sharding())) { + input_partitioned = input_partitioned.PadWithValue( + CreateR0WithType(hlo->shape().element_type(), 0, &b_)); + } + auto input = input_partitioned.hlo(); + CHECK(hlo->sharding().IsReplicated()); + CHECK(ShapeUtil::Compatible(input->shape(), hlo->shape())); + auto copy = b_.AddInstruction( + HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input)); + SetPartitionedHlo(hlo, [&] { return copy; }); + return Status::OK(); + } + if (hlo->custom_call_target() == "SPMDShardToFullShape") { + // This op switches from manual partitioning to auto partitioning. + auto input = GetPartitionedHlo(hlo->operand(0)).hlo(); + CHECK(input->sharding().IsReplicated()); + auto copy = b_.AddInstruction( + HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input)); + CHECK(ShapeUtil::Compatible( + copy->shape(), MakePartitionedShape(hlo->shape(), hlo->sharding()))); + SetPartitionedHlo(hlo, [&] { return copy; }); + return Status::OK(); + } + if (hlo->custom_call_target() != "TopK") { + return DefaultAction(hlo); + } + + if (!hlo->operand(0)->has_sharding()) { + return DefaultAction(hlo); + } + + const HloSharding& sharding = hlo->operand(0)->sharding(); + if (sharding.IsTileMaximal() || sharding.IsReplicated()) { + return DefaultAction(hlo); + } + + const int64 sort_dim = 1; + const int64 shard_count = sharding.tile_assignment().dim(sort_dim); + + if (shard_count <= 1) { + return DefaultAction(hlo); + } + + const int64 input_size = hlo->operand(0)->shape().dimensions(sort_dim); + const int64 batch_size = hlo->shape().tuple_shapes(0).dimensions(0); + const int64 k = hlo->shape().tuple_shapes(0).dimensions(sort_dim); + const int64 per_partition_size = CeilOfRatio(input_size, shard_count); + + if (k >= per_partition_size) { + return DefaultAction(hlo); + } + + auto input = hlo->operand(0); + const auto element_type = input->shape().element_type(); + + // Pad input with minimal value. + auto min_value = b_.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::MinValue(element_type))); + // TODO(wangtao): add test to see if -NaN < -Inf in BF16. + if (element_type == F32) { + auto float_pad_value = std::numeric_limits::quiet_NaN(); + min_value = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(-float_pad_value))); + } + auto partitioned_input = GetPartitionedHlo(input).PadWithValue(min_value); + + // Each partition needs to do TopK separately, thus the base shape + // becomes [batch_size, k * shard_count]. + const Shape replicated_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(hlo->operand(0)->shape().element_type(), + {batch_size, k * shard_count}), + ShapeUtil::MakeShape(S32, {batch_size, k * shard_count})}); + auto custom_call_sharding = + sharding.GetTupleSharding(replicated_shape).ValueOrDie(); + auto shard_shape = + MakePartitionedShape(replicated_shape, custom_call_sharding); + auto topk = b_.AddInstruction( + hlo->CloneWithNewOperands(shard_shape, {partitioned_input.hlo()})); + topk->set_sharding(custom_call_sharding); + // Partition customcall. + PartitionedHlo partitioned_topk(topk, replicated_shape, + MakePartitioningState()); + topk = partitioned_topk.hlo(); + + // Get value from TopK. + HloInstruction* value_gte = + b_.AddInstruction(HloInstruction::CreateGetTupleElement( + topk->shape().tuple_shapes(0), topk, 0)); + value_gte->set_sharding(sharding); + // Partition GetTupleElement of value. 
+ PartitionedHlo value_partitioned_gte( + value_gte, partitioned_topk.base_shape().tuple_shapes(0), + MakePartitioningState()); + // Reshard value to be replicated. + auto replicated_value_gte = + value_partitioned_gte.Reshard(HloSharding::Replicate()).hlo(); + + // Get index from TopK. + HloInstruction* index_gte = + b_.AddInstruction(HloInstruction::CreateGetTupleElement( + topk->shape().tuple_shapes(1), topk, 1)); + auto partition_id_s32 = b_.AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::MakeShape(S32, partition_id_->shape().dimensions()), + partition_id_)); + // Add per partition offset to index, index returned from CustomCall always + // starts from 0. + auto index_offset = b_.AddInstruction(HloInstruction::CreateBroadcast( + index_gte->shape(), + b_.AddInstruction(HloInstruction::CreateBinary( + partition_id_s32->shape(), HloOpcode::kMultiply, partition_id_s32, + b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(per_partition_size))))), + {})); + index_gte = b_.AddInstruction(HloInstruction::CreateBinary( + index_offset->shape(), HloOpcode::kAdd, index_gte, index_offset)); + index_gte->set_sharding(sharding); + // Parttion GetTupleElement of index. + PartitionedHlo index_partitioned_gte( + index_gte, partitioned_topk.base_shape().tuple_shapes(1), + MakePartitioningState()); + // Reshard index to be replicated. + auto replicated_index_gte = + index_partitioned_gte.Reshard(HloSharding::Replicate()).hlo(); + + // Creates replicated sort to do TopK, the input is value and index pairs + // from all the partitions. The reason to use Sort instead of CustomCall TopK + // is CustomCall only takes value as input. There will be an extra Gather + // to get the correct index if CustomCall is used here. + + // Create comparator for the sort. + XlaBuilder b("Sort.Compare"); + XlaComputation comparator = CreateScalarComparisonComputation( + "compare-value-and-index", {input->shape().element_type(), S32}, {Gt, Lt}, + &b); + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comparator.GetProgramShape()); + HloModuleConfig config(program_shape); + TF_ASSIGN_OR_RETURN(auto new_module, + HloModule::CreateFromProto(comparator.proto(), config)); + HloCloneContext context(module_); + auto compare_computation = + module_->DeepCloneComputation(new_module->entry_computation(), &context); + auto sort = b_.AddInstruction(HloInstruction::CreateSort( + replicated_shape, sort_dim, {replicated_value_gte, replicated_index_gte}, + compare_computation, true)); + sort->set_sharding( + HloSharding::Replicate().GetTupleSharding(sort->shape()).ValueOrDie()); + PartitionedHlo replicated_sort(sort, replicated_shape, + MakePartitioningState()); + + // Slice value and index from top-k for output. + HloInstruction* sort_value_gte = + b_.AddInstruction(HloInstruction::CreateGetTupleElement( + replicated_sort.hlo()->shape().tuple_shapes(0), replicated_sort.hlo(), + 0)); + HloInstruction* sort_index_gte = + b_.AddInstruction(HloInstruction::CreateGetTupleElement( + replicated_sort.hlo()->shape().tuple_shapes(1), replicated_sort.hlo(), + 1)); + const Shape& hlo_shape = sort_value_gte->shape(); + auto hlo_dims = hlo_shape.dimensions(); + std::vector start_indices(hlo_shape.dimensions_size(), 0); + std::vector limit_indices(hlo_dims.begin(), hlo_dims.end()); + std::vector strides(hlo_shape.dimensions_size(), sort_dim); + limit_indices[sort_dim] = k; + auto output_shape = hlo_shape; + output_shape.set_dimensions(sort_dim, k); + // Slice value from final sort. 
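Summarizing the TopK strategy above before the final slicing: each partition takes a local top-k, shifts its local indices by partition_id * per_partition_size, and a global sort over the k * shard_count candidates produces the answer. A compact simulation for a single batch row, with std::partial_sort standing in for both the per-shard CustomCall and the replicated sort (plain C++, not the XLA API; index tie-breaking is ignored):

#include <algorithm>
#include <cstdint>
#include <functional>
#include <iostream>
#include <utility>
#include <vector>

using ValueIndex = std::pair<float, int64_t>;  // (value, global index)

// Local top-k of one partition's shard, with indices shifted to global ones.
std::vector<ValueIndex> LocalTopK(const std::vector<float>& shard, int64_t k,
                                  int64_t partition_id,
                                  int64_t per_partition_size) {
  std::vector<ValueIndex> candidates;
  for (int64_t i = 0; i < static_cast<int64_t>(shard.size()); ++i) {
    candidates.emplace_back(shard[i], partition_id * per_partition_size + i);
  }
  std::partial_sort(candidates.begin(), candidates.begin() + k,
                    candidates.end(), std::greater<ValueIndex>());
  candidates.resize(k);
  return candidates;
}

int main() {
  const int64_t k = 2, per_partition_size = 4;
  const std::vector<std::vector<float>> shards = {{0.1f, 9.0f, 3.0f, 2.0f},
                                                  {8.0f, 0.5f, 7.0f, 1.0f}};
  // Gather k candidates per partition, then resolve them with a global sort
  // (the replicated sort over value/index pairs in the code above).
  std::vector<ValueIndex> merged;
  for (int64_t p = 0; p < static_cast<int64_t>(shards.size()); ++p) {
    const auto local = LocalTopK(shards[p], k, p, per_partition_size);
    merged.insert(merged.end(), local.begin(), local.end());
  }
  std::partial_sort(merged.begin(), merged.begin() + k, merged.end(),
                    std::greater<ValueIndex>());
  for (int64_t i = 0; i < k; ++i) {
    std::cout << "value=" << merged[i].first << " index=" << merged[i].second
              << "\n";  // 9 @ index 1, then 8 @ index 4
  }
  return 0;
}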
+ HloInstruction* slice_sort_value = + b_.AddInstruction(HloInstruction::CreateSlice( + output_shape, sort_value_gte, start_indices, limit_indices, strides)); + // Slice index from final sort. + auto index_output_shape = sort_index_gte->shape(); + index_output_shape.set_dimensions(sort_dim, k); + HloInstruction* slice_index_value = b_.AddInstruction( + HloInstruction::CreateSlice(index_output_shape, sort_index_gte, + start_indices, limit_indices, strides)); + auto create_tuple = b_.AddInstruction( + HloInstruction::CreateTuple({slice_sort_value, slice_index_value})); + create_tuple->set_sharding(HloSharding::Replicate()); + + SetPartitionedHlo(hlo, PartitionedHlo(create_tuple, create_tuple->shape(), + MakePartitioningState()) + .Reshard(hlo->sharding())); + + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleTranspose(HloInstruction* hlo) { + const HloSharding& sharding = hlo->sharding(); + if (sharding.IsTileMaximal()) { + return DefaultAction(hlo); + } + + std::vector inverse_dimensions(hlo->shape().rank()); + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + inverse_dimensions[hlo->dimensions(i)] = i; + } + auto desired_operand_sharding = + hlo_sharding_util::TransposeSharding(sharding, inverse_dimensions); + + auto operand = GetPartitionedHlo(hlo->operand(0)) + .Reshard(desired_operand_sharding) + .hlo(); + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction(hlo->CloneWithNewOperands( + MakePartitionedShape(hlo->shape(), hlo->sharding()), {operand})); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) { + const HloSharding& sharding = hlo->sharding(); + if (sharding.IsTileMaximal()) { + return DefaultAction(hlo); + } + + auto operand = GetPartitionedHlo(hlo->operand(0)); + // The output shape is the source and the operand shape is the target to get + // the aligned sharding for the operand. + auto desired_operand_sharding = hlo_sharding_util::ReshapeSharding( + hlo->shape(), hlo->operand(0)->shape(), hlo->sharding()); + if (desired_operand_sharding.has_value()) { + auto operand_hlo = operand.Reshard(*desired_operand_sharding).hlo(); + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction(hlo->CloneWithNewOperands( + MakePartitionedShape(hlo->shape(), hlo->sharding()), {operand_hlo})); + }); + return Status::OK(); + } + + // Try use halo exchange for certain split-dim/merge-dims cases. + // ReshapeSharding failed in these cases probably due to uneven partitioning, + // where halo exchange could help. Specifically we check the following + // conditions to detect supported cases: + // 1) Both input and output are partitioned on one dimension. + // 2) The combined size of dimensions before the partitioned dimension are the + // same on input and output. This means we don't need to consider the major + // dimensions. + // 3) Let A = the input size on the partitioned dimension, and + // B = the output size on the partitioned dimension; then + // either A % B == 0 (split dim) or B % A == 0 (merge dims). + auto maybe_input_sharded_dim = UniqueTiledDim(operand.sharding()); + auto maybe_output_sharded_dim = UniqueTiledDim(sharding); + if (!maybe_input_sharded_dim || !maybe_output_sharded_dim) { + return DefaultAction(hlo); + } + int64 input_sharded_dim = *maybe_input_sharded_dim; + int64 output_sharded_dim = *maybe_output_sharded_dim; + // Check that the major dims before the sharded dim have the same total size + // for input and output. 
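In isolation, the checks performed below amount to: both sides tiled on one dimension, equal products of the dimensions in front of it, and one tiled size dividing the other. A small sketch of that classification with shapes as plain dimension vectors (hypothetical helper, illustrative only):

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Classify a reshape whose input is tiled on `in_dim` and whose output is
// tiled on `out_dim` as "split", "merge", or "unsupported", mirroring the
// conditions checked in HandleReshape above.
std::string ClassifyReshape(const std::vector<int64_t>& in_dims, int64_t in_dim,
                            const std::vector<int64_t>& out_dims,
                            int64_t out_dim) {
  int64_t in_major = 1, out_major = 1;
  for (int64_t i = 0; i < in_dim; ++i) in_major *= in_dims[i];
  for (int64_t i = 0; i < out_dim; ++i) out_major *= out_dims[i];
  if (in_major != out_major) return "unsupported";
  if (in_dims[in_dim] % out_dims[out_dim] == 0) return "split";
  if (out_dims[out_dim] % in_dims[in_dim] == 0) return "merge";
  return "unsupported";
}

int main() {
  // [6, 8] -> [6, 2, 4], both tiled on the size-8 / size-2 dimension.
  std::cout << ClassifyReshape({6, 8}, 1, {6, 2, 4}, 1) << "\n";  // split
  // [6, 2, 4] -> [6, 8].
  std::cout << ClassifyReshape({6, 2, 4}, 1, {6, 8}, 1) << "\n";  // merge
  return 0;
}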
+ int64 input_major_dims_size = 1; + for (int64 i = 0; i < input_sharded_dim; ++i) { + input_major_dims_size *= operand.base_shape().dimensions(i); + } + int64 output_major_dims_size = 1; + for (int64 i = 0; i < output_sharded_dim; ++i) { + output_major_dims_size *= hlo->shape().dimensions(i); + } + if (input_major_dims_size != output_major_dims_size) { + return DefaultAction(hlo); + } + // Fix potential device ordering mismatch in tile assignment. + Array new_input_tile_assignment = sharding.tile_assignment(); + new_input_tile_assignment.Reshape( + operand.sharding().tile_assignment().dimensions()); + operand = operand.Reshard(HloSharding::Tile(new_input_tile_assignment)); + + int64 input_dim_size = operand.base_shape().dimensions(input_sharded_dim); + int64 output_dim_size = hlo->shape().dimensions(output_sharded_dim); + auto input_shard_shape = + MakePartitionedShape(operand.base_shape(), operand.sharding()); + auto output_shard_shape = MakePartitionedShape(hlo->shape(), sharding); + if (input_dim_size % output_dim_size == 0) { + // Split dim. + int64 split_factor = input_dim_size / output_dim_size; + int64 output_shard_size = output_shard_shape.dimensions(output_sharded_dim); + // Use halo exchange to fix misaligned data. + Window window; + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_size(1); + dim->set_stride(1); + dim->set_window_dilation(1); + dim->set_window_reversal(false); + dim->set_base_dilation(1); + dim->set_padding_low(0); + if (i == input_sharded_dim) { + dim->set_padding_high(output_shard_size * split_factor * + num_partitions_ - + input_dim_size); + } else { + dim->set_padding_high(0); + } + } + + auto reshard_operand = operand.ReshardAsWindowedInput( + window, operand.sharding(), + CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_), + /*mask_invalid_region=*/false); + if (!reshard_operand.has_value()) { + return DefaultAction(hlo); + } + TF_RET_CHECK(!reshard_operand->dynamic_slice_index_on_output.has_value()); + CHECK_EQ( + reshard_operand->sharded_input->shape().dimensions(input_sharded_dim), + output_shard_size * split_factor); + SetPartitionedHlo(hlo, [&] { + // Do a local reshape. + return b_.AddInstruction(HloInstruction::CreateReshape( + output_shard_shape, reshard_operand->sharded_input)); + }); + return Status::OK(); + } else if (output_dim_size % input_dim_size == 0) { + // Merge dims. + int64 merge_factor = output_dim_size / input_dim_size; + // First reshape locally. (The sharded dimension could include padded data.) + auto tmp_shard_shape = output_shard_shape; + tmp_shard_shape.set_dimensions( + output_sharded_dim, + input_shard_shape.dimensions(input_sharded_dim) * merge_factor); + auto tmp_reshape = b_.AddInstruction( + HloInstruction::CreateReshape(tmp_shard_shape, operand.hlo())); + tmp_reshape->set_metadata(hlo->metadata()); + tmp_reshape->set_sharding(hlo->sharding()); + auto tmp_full_shape = tmp_shard_shape; + tmp_full_shape.set_dimensions( + output_sharded_dim, + tmp_shard_shape.dimensions(output_sharded_dim) * num_partitions_); + auto tmp_output = + PartitionedHlo(tmp_reshape, tmp_full_shape, MakePartitioningState()); + + // Use halo exchange to fix misaligned data. 
+ Window window; + for (int64 i = 0; i < tmp_shard_shape.rank(); ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_size(1); + dim->set_stride(1); + dim->set_window_dilation(1); + dim->set_window_reversal(false); + dim->set_base_dilation(1); + dim->set_padding_low(0); + if (i == output_sharded_dim) { + dim->set_padding_high(output_dim_size - + tmp_shard_shape.dimensions(output_sharded_dim) * + num_partitions_); + } else { + dim->set_padding_high(0); + } + } + + auto reshard_output = tmp_output.ReshardAsWindowedInput( + window, sharding, + CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_), + /*mask_invalid_region=*/false); + if (!reshard_output.has_value()) { + return DefaultAction(hlo); + } + TF_RET_CHECK(!reshard_output->dynamic_slice_index_on_output.has_value()); + CHECK_EQ( + reshard_output->sharded_input->shape().dimensions(output_sharded_dim), + output_shard_shape.dimensions(output_sharded_dim)); + SetPartitionedHlo(hlo, [&] { return reshard_output->sharded_input; }); + return Status::OK(); + } + return DefaultAction(hlo); +} + +Status SpmdPartitioningVisitor::HandleIota(HloInstruction* hlo) { + const HloSharding& sharding = hlo->sharding(); + if (sharding.IsTileMaximal()) { + return DefaultAction(hlo); + } + + SetPartitionedHlo(hlo, [&] { + int64 dimension = Cast(hlo)->iota_dimension(); + auto iota = b_.AddInstruction(HloInstruction::CreateIota( + MakePartitionedShape(hlo->shape(), sharding), dimension)); + + if (sharding.tile_assignment().dim(dimension) > 1) { + auto partition_ordinals = + MakeTiledPartitionOrdinals(sharding, partition_id_, &b_); + auto multiplier = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(iota->shape().dimensions(dimension)))); + auto offset = b_.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(S32, {}), HloOpcode::kMultiply, + partition_ordinals[dimension], multiplier)); + if (iota->shape().element_type() != S32) { + offset = b_.AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::MakeShape(iota->shape().element_type(), {}), offset)); + } + auto broadcast = b_.AddInstruction( + HloInstruction::CreateBroadcast(iota->shape(), offset, {})); + return b_.AddInstruction(HloInstruction::CreateBinary( + iota->shape(), HloOpcode::kAdd, iota, broadcast)); + } + + return iota; + }); + + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleSingleDevice(const HloInstruction* hlo) { + TF_RET_CHECK(hlo->sharding().HasUniqueDevice()); + int64 device = hlo->sharding().GetUniqueDevice(); + const HloSharding sharding = HloSharding::AssignDevice(device); + + std::vector operands; + std::vector operand_shapes; + for (const HloInstruction* operand : hlo->operands()) { + operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo()); + operand_shapes.push_back(operand->shape()); + } + auto operand = b_.AddInstruction(HloInstruction::CreateTuple(operands)); + auto operand_shape = ShapeUtil::MakeTupleShape(operand_shapes); + + auto on_device = b_.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(device))); + auto pred = b_.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), partition_id_, on_device, + ComparisonDirection::kEq)); + + SpmdBuilder true_b("true_computation", visiting_hlo_); + HloComputation* true_computation; + { + auto param = true_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, operand_shape, "true_branch_param")); + std::vector new_operands; + for (int64 i = 0; i < operands.size(); 
++i) { + new_operands.push_back(true_b.AddInstruction( + HloInstruction::CreateGetTupleElement(operand_shapes[i], param, i))); + } + auto root = true_b.AddInstruction( + hlo->CloneWithNewOperands(hlo->shape(), new_operands)); + true_computation = module_->AddEmbeddedComputation(true_b.Build(root)); + } + + SpmdBuilder false_b("false_computation", visiting_hlo_); + HloComputation* false_computation; + { + false_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, operand_shape, "false_branch_param")); + auto root = CreateZero(hlo->shape(), &false_b); + false_computation = module_->AddEmbeddedComputation(false_b.Build(root)); + } + + SetPartitionedHlo(hlo, [&]() { + return b_.AddInstruction(HloInstruction::CreateConditional( + hlo->shape(), pred, operand, true_computation, operand, + false_computation)); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleAllReduce(HloInstruction* hlo) { + if (hlo->IsCrossReplicaAllReduce() && hlo->operand_count() == 1) { + return HandleElementwise(hlo); + } + return DefaultAction(hlo); +} + +Status SpmdPartitioningVisitor::HandleBroadcast(HloInstruction* hlo) { + if (hlo->sharding().IsTileMaximal()) { + return DefaultAction(hlo); + } + + auto& operand = GetPartitionedHlo(hlo->operand(0)); + + // Tiled output. + std::vector wanted_input_tile_size(operand.base_shape().rank()); + std::vector sharded_new_dims; + for (int64 i = 0; i < operand.base_shape().rank(); ++i) { + wanted_input_tile_size[i] = + hlo->sharding().tile_assignment().dim(hlo->dimensions(i)); + } + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + if (!absl::c_linear_search(hlo->dimensions(), i) && + hlo->sharding().tile_assignment().dim(i) > 1) { + sharded_new_dims.push_back(i); + } + } + if (sharded_new_dims.empty()) { + // The new dimensions are replicated, so that we can do the adjustment on + // the input. + Array wanted_input_tile_assignment(wanted_input_tile_size); + wanted_input_tile_assignment.Each( + [&](absl::Span indices, int64* val) { + std::vector indices_in_broadcast(hlo->shape().rank(), 0); + for (int64 i = 0; i < operand.base_shape().rank(); ++i) { + indices_in_broadcast[hlo->dimensions(i)] = indices[i]; + } + *val = hlo->sharding().tile_assignment()(indices_in_broadcast); + }); + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction(hlo->CloneWithNewOperands( + MakePartitionedShape(hlo->shape(), hlo->sharding()), + {operand.Reshard(HloSharding::Tile(wanted_input_tile_assignment)) + .hlo()})); + }); + } else { + auto input = operand.Reshard(HloSharding::Replicate()).hlo(); + // We pad and shard the input first, then broadcast to the final shard + // shape. + auto output_offsets = + MakePartitionOffsets(hlo->shape(), hlo->sharding(), partition_id_, &b_); + std::vector input_offsets(operand.base_shape().rank()); + auto output_shard_shape = + MakePartitionedShape(hlo->shape(), hlo->sharding()); + auto input_shard_shape = input->shape(); + auto padded_input_shape = input->shape(); + for (int64 i = 0; i < input_offsets.size(); ++i) { + input_offsets[i] = output_offsets[hlo->dimensions(i)]; + input_shard_shape.set_dimensions( + i, output_shard_shape.dimensions(hlo->dimensions(i))); + padded_input_shape.set_dimensions( + i, hlo->sharding().tile_assignment().dim(hlo->dimensions(i)) * + input_shard_shape.dimensions(i)); + } + auto padded_input = PadToShape(input, padded_input_shape, &b_); + auto input_shard = + ShapeUtil::Compatible(input_shard_shape, padded_input->shape()) + ? 
padded_input + : b_.AddInstruction(HloInstruction::CreateDynamicSlice( + input_shard_shape, padded_input, input_offsets, + input_shard_shape.dimensions())); + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction( + hlo->CloneWithNewOperands(output_shard_shape, {input_shard})); + }); + } + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleConstant(HloInstruction* hlo) { + const Literal& literal = hlo->literal(); + if (literal.shape().IsTuple() || + (!hlo->sharding().IsTileMaximal() && + (!EvenlyPartitions(hlo->shape(), hlo->sharding()) || + !literal.IsAllFirst()))) { + return DefaultAction(hlo); + } + + SetPartitionedHlo(hlo, [&]() { + auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); + std::vector start_indices(hlo->shape().rank(), 0); + auto constant = b_.AddInstruction(HloInstruction::CreateConstant( + literal.Slice(start_indices, shard_shape.dimensions()))); + *constant->mutable_shape() = shard_shape; + return constant; + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleDynamicSlice(HloInstruction* hlo) { + if (hlo->sharding().IsTileMaximal()) { + return DefaultAction(hlo); + } + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + if (hlo->sharding().tile_assignment().dim(i) != 1 && + (hlo->dynamic_slice_sizes()[i] != hlo->shape().dimensions(i) || + !hlo->operand(i + 1)->IsConstant() || + !hlo->operand(i + 1)->literal().IsZero({}))) { + // We currently do not partition the sliced dimensions. + return DefaultAction(hlo); + } + } + std::vector new_indices(hlo->shape().rank()); + auto new_input = + GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding()).hlo(); + for (int64 i = 0; i < new_indices.size(); ++i) { + // Replicate the indices. + new_indices[i] = GetPartitionedHlo(hlo->operand(i + 1)) + .Reshard(HloSharding::Replicate()) + .hlo(); + } + SetPartitionedHlo(hlo, [&]() { + auto partitioned_shape = + MakePartitionedShape(hlo->shape(), hlo->sharding()); + return b_.AddInstruction(HloInstruction::CreateDynamicSlice( + partitioned_shape, new_input, new_indices, + partitioned_shape.dimensions())); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleDynamicUpdateSlice(HloInstruction* hlo) { + if (hlo->sharding().IsTileMaximal()) { + return DefaultAction(hlo); + } + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + if (hlo->sharding().tile_assignment().dim(i) != 1 && + (hlo->operand(1)->shape().dimensions(i) != hlo->shape().dimensions(i) || + !hlo->operand(i + 2)->IsConstant() || + !hlo->operand(i + 2)->literal().IsZero({}))) { + // We currently do not partition the sliced dimensions. + return DefaultAction(hlo); + } + } + std::vector new_indices(hlo->shape().rank()); + auto new_input = + GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding()).hlo(); + auto new_update = + GetPartitionedHlo(hlo->operand(1)).Reshard(hlo->sharding()).hlo(); + for (int64 i = 0; i < new_indices.size(); ++i) { + // Replicate the indices. 
+ new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2)) + .Reshard(HloSharding::Replicate()) + .hlo(); + } + SetPartitionedHlo(hlo, [&]() { + auto partitioned_shape = + MakePartitionedShape(hlo->shape(), hlo->sharding()); + return b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + partitioned_shape, new_input, new_update, new_indices)); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) { + auto gather = Cast(hlo); + const auto& dnums = gather->gather_dimension_numbers(); + auto operand = GetPartitionedHlo(gather->operand(0)); + auto indices = GetPartitionedHlo(gather->operand(1)); + std::vector collapsed_slice_dims(dnums.collapsed_slice_dims().begin(), + dnums.collapsed_slice_dims().end()); + std::vector start_index_map(dnums.start_index_map().begin(), + dnums.start_index_map().end()); + std::vector offset_dims(dnums.offset_dims().begin(), + dnums.offset_dims().end()); + if (!operand.sharding().IsTileMaximal()) { + auto maybe_passthrough = PassthroughOperandToGatherOutputOrScatterUpdate( + operand, gather->shape(), collapsed_slice_dims, start_index_map, + offset_dims, gather->gather_slice_sizes()); + if (maybe_passthrough.has_value()) { + indices = indices.Reshard(HloSharding::Replicate()); + auto pshape = MakePartitionedShape(gather->shape(), *maybe_passthrough); + std::vector pslice_sizes(gather->gather_slice_sizes().begin(), + gather->gather_slice_sizes().end()); + for (int64 i = 0; i < pslice_sizes.size(); ++i) { + if (operand.sharding().tile_assignment().dim(i) > 1) { + pslice_sizes[i] = operand.hlo()->shape().dimensions(i); + } + } + auto pgather = b_.AddInstruction(HloInstruction::CreateGather( + pshape, operand.hlo(), indices.hlo(), dnums, pslice_sizes, + gather->indices_are_sorted())); + pgather->set_sharding(*maybe_passthrough); + SetPartitionedHlo(hlo, [&]() { + return PartitionedHlo(pgather, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); + } + if (GatherScatterOperandPartitionedOnlyOnTrivialSliceDims( + operand, start_index_map, gather->gather_slice_sizes(), + num_partitions_) && + ShapeUtil::ByteSizeOf(gather->shape()) < + ShapeUtil::ByteSizeOf(gather->operand(0)->shape())) { + indices = indices.Reshard(HloSharding::Replicate()); + // Now the operand is partitioned in trivial slice dimensions, and the + // indices are replicated. We execute a gather on partitioned operand, + // with full number of indices, where out-of-bounds indices are clamped, + // and masked out with 0 in the result; then we use all-reduce to combine + // results. Although gather will not get faster, we avoided the need to + // replicate the operand. + HloInstruction* indices_min; + HloInstruction* indices_max; + std::tie(indices_min, indices_max) = + IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims( + operand, indices, partition_id_, start_index_map, + dnums.index_vector_dim(), &b_); + // Clamp the indices. + auto adjusted_indices = b_.AddInstruction(HloInstruction::CreateTernary( + indices.base_shape(), HloOpcode::kClamp, indices_min, indices.hlo(), + indices_max)); + // Adjust the indices by subtracting the offset. + adjusted_indices = b_.AddInstruction(HloInstruction::CreateBinary( + indices.base_shape(), HloOpcode::kSubtract, adjusted_indices, + indices_min)); + // Gather on adjusted indices. 
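A 1-D simulation of the gather path that follows, for an operand partitioned on a trivial (slice size 1) dimension: indices stay replicated, each partition clamps them into the block it owns, gathers from its local shard after subtracting its offset, masks out results it does not own, and the masked partial results are summed, standing in for the cross-partition all-reduce (illustrative only, not the XLA API):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t num_partitions = 2;
  const std::vector<int64_t> operand = {10, 11, 12, 13, 14, 15};
  const std::vector<int64_t> indices = {5, 0, 3};  // replicated everywhere
  const int64_t shard = static_cast<int64_t>(operand.size()) / num_partitions;

  std::vector<int64_t> result(indices.size(), 0);
  for (int64_t p = 0; p < num_partitions; ++p) {
    const int64_t min = p * shard, max = min + shard - 1;
    // Each partition holds only its own block of the operand.
    const std::vector<int64_t> local(operand.begin() + min,
                                     operand.begin() + min + shard);
    for (size_t i = 0; i < indices.size(); ++i) {
      const int64_t clamped = std::min(std::max(indices[i], min), max);
      const int64_t gathered = local[clamped - min];  // adjusted local gather
      const bool invalid = indices[i] < min || indices[i] > max;
      result[i] += invalid ? 0 : gathered;  // mask, then emulated all-reduce
    }
  }
  for (int64_t v : result) std::cout << v << " ";  // prints: 15 10 13
  std::cout << "\n";
  return 0;
}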
+ auto pgather = b_.AddInstruction(HloInstruction::CreateGather( + gather->shape(), operand.hlo(), adjusted_indices, dnums, + gather->gather_slice_sizes(), gather->indices_are_sorted())); + // Mask out invalid results. + auto filter = b_.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::ChangeElementType(indices.base_shape(), PRED), + indices.hlo(), indices_min, ComparisonDirection::kLt)); + filter = b_.AddInstruction(HloInstruction::CreateBinary( + filter->shape(), HloOpcode::kOr, filter, + b_.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::ChangeElementType(indices.base_shape(), PRED), + indices.hlo(), indices_max, ComparisonDirection::kGt)))); + if (dnums.index_vector_dim() < indices.base_shape().rank()) { + std::vector reduced_filter_dims; + for (int64 i = 0; i < filter->shape().rank(); ++i) { + if (i != dnums.index_vector_dim()) { + reduced_filter_dims.push_back(filter->shape().dimensions(i)); + } + } + filter = b_.AddInstruction(HloInstruction::CreateReduce( + ShapeUtil::MakeShape(PRED, reduced_filter_dims), filter, + CreateR0WithType(PRED, false, &b_), {dnums.index_vector_dim()}, + MakeBinaryAdd(PRED, module_))); + } + std::vector batch_dims; + for (int64 i = 0; i < pgather->shape().rank(); ++i) { + if (!absl::c_linear_search(dnums.offset_dims(), i)) { + batch_dims.push_back(i); + } + } + auto broadcast_filter = b_.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::ChangeElementType(pgather->shape(), PRED), filter, + batch_dims)); + auto filtered = b_.AddInstruction(HloInstruction::CreateTernary( + pgather->shape(), HloOpcode::kSelect, broadcast_filter, + CreateZero(pgather->shape(), &b_), pgather)); + // Combine from different partitions. + auto ar = collective_ops_creator_.create_cross_partition_all_reduce( + &b_, filtered, + MakeBinaryAdd(filtered->shape().element_type(), module_), + NewChannel()); + ar->set_sharding(HloSharding::Replicate()); + SetPartitionedHlo(hlo, [&]() { + return PartitionedHlo(ar, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); + } + } + return DefaultAction(hlo); +} + +Status SpmdPartitioningVisitor::HandleGetTupleElement(HloInstruction* hlo) { + const auto& tuple = GetPartitionedHlo(hlo->operand(0)); + auto gte = b_.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(tuple.hlo()->shape(), hlo->tuple_index()), + tuple.hlo(), hlo->tuple_index())); + SetPartitionedHlo(hlo, [&]() { + const auto source_sharding = tuple.sharding().GetSubSharding( + tuple.base_shape(), {hlo->tuple_index()}); + gte->set_sharding(source_sharding); + PartitionedHlo source_partitioned_gte(gte, hlo->shape(), + MakePartitioningState()); + return source_partitioned_gte.Reshard(hlo->sharding()).hlo(); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleInfeed(HloInstruction* hlo) { + const Shape& shape = ShapeUtil::GetTupleElementShape(hlo->shape(), 0); + auto token = GetPartitionedHlo(hlo->operand(0)).hlo(); + if (ShapeUtil::GetLeafCount(shape) == 0) { + // TODO(b/155819021): HloSharding has issues with tuple-shaped sharding: it + // requires one element for an empty tuple, but leaf-count number of + // elements for non-empty tuple. So if it has a nested empty tuple, we + // cannot invoke GetSubSharding() since it expects a sharding for the empty + // tuple. This is a workaround for that case. 
+ SetPartitionedHlo(hlo, [&]() { + return b_.AddInstruction( + HloInstruction::CreateInfeed(shape, token, hlo->infeed_config())); + }); + return Status::OK(); + } + auto sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0}); + auto shard_shape = MakePartitionedShape(shape, sharding); + if (EvenlyPartitions(shape, sharding)) { + SetPartitionedHlo(hlo, [&]() { + return b_.AddInstruction(HloInstruction::CreateInfeed( + shard_shape, token, hlo->infeed_config())); + }); + return Status::OK(); + } + + if (hlo->sharding().HasUniqueDevice()) { + return HandleSingleDevice(hlo); + } + + // Create a branch for each unique partitioned shape. + std::vector per_branch_partitioned_shapes; + std::vector conditional_branch_indices(num_partitions_); + for (int64 i = 0; i < num_partitions_; ++i) { + auto partitioned_shape = + MakeNonPaddedShapeForGivenPartition(shape, sharding, i); + int64 matching_existing_index = 0; + for (; matching_existing_index < per_branch_partitioned_shapes.size(); + ++matching_existing_index) { + if (ShapeUtil::Compatible( + partitioned_shape, + per_branch_partitioned_shapes[matching_existing_index])) { + break; + } + } + if (matching_existing_index < per_branch_partitioned_shapes.size()) { + conditional_branch_indices[i] = matching_existing_index; + } else { + conditional_branch_indices[i] = per_branch_partitioned_shapes.size(); + per_branch_partitioned_shapes.push_back(std::move(partitioned_shape)); + } + } + + HloInstruction* branch_index; + if (per_branch_partitioned_shapes.size() == num_partitions_) { + // Use partition ID as the branch index if each partition has its own + // branch. + branch_index = partition_id_; + // PartitionId's output is U32 but conditional requires S32. + if (branch_index->shape().element_type() != S32) { + branch_index = b_.AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(branch_index->shape(), S32), + branch_index)); + } + } else { + // Otherwise, use a constant table to look up the branch index. + auto branch_index_table = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1(conditional_branch_indices))); + branch_index = b_.AddInstruction(HloInstruction::CreateDynamicSlice( + ShapeUtil::MakeShape(S32, {1}), branch_index_table, {partition_id_}, + {1})); + branch_index = b_.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(S32, {}), branch_index)); + } + + std::vector branches(per_branch_partitioned_shapes.size()); + for (int64 i = 0; i < branches.size(); ++i) { + SpmdBuilder branch_b(absl::StrCat("infeed_branch_", i), visiting_hlo_); + auto param = branch_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, token->shape(), "infeed_token_param")); + auto infeed = branch_b.AddInstruction(HloInstruction::CreateInfeed( + per_branch_partitioned_shapes[i], param, hlo->infeed_config())); + branches[i] = module_->AddEmbeddedComputation(branch_b.Build(infeed)); + if (!ShapeUtil::Compatible(per_branch_partitioned_shapes[i], shard_shape)) { + TF_ASSIGN_OR_RETURN( + auto padded, + branches[i]->DeepCopyInstructionWithCustomCopier( + infeed, [&](HloInstruction* leaf, const ShapeIndex& leaf_index, + HloComputation* comp) { + // Index {1} corresponds to the token. 
+ if (leaf_index.empty() || leaf_index[0] != 0) { + return leaf; + } + ShapeIndexView subindex(leaf_index, 1); + if (ShapeUtil::Compatible( + ShapeUtil::GetSubshape(per_branch_partitioned_shapes[i], + subindex), + ShapeUtil::GetSubshape(shard_shape, subindex))) { + return leaf; + } + return PadToShape(leaf, + ShapeUtil::GetSubshape(shard_shape, subindex), + nullptr, comp); + })); + branches[i]->set_root_instruction(padded, + /*accept_different_shape=*/true); + } + } + SetPartitionedHlo(hlo, [&]() { + return b_.AddInstruction(HloInstruction::CreateConditional( + ShapeUtil::MakeTupleShape({shard_shape, token->shape()}), branch_index, + branches, std::vector(branches.size(), token))); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandlePad(HloInstruction* hlo) { + if (hlo->sharding().IsTileMaximal()) { + return DefaultAction(hlo); + } + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + const auto& pd = hlo->padding_config().dimensions(i); + // Right now we only support non-padded dimensions to be partitioned. + if (hlo->sharding().tile_assignment().dim(i) > 1 && + (pd.edge_padding_high() != 0 || pd.edge_padding_low() != 0 || + pd.interior_padding() != 0)) { + return DefaultAction(hlo); + } + } + auto resharded_lhs = + GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding()).hlo(); + auto replicated_rhs = GetPartitionedHlo(hlo->operand(1)) + .Reshard(HloSharding::Replicate()) + .hlo(); + SetPartitionedHlo(hlo, [&]() { + auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); + return b_.AddInstruction(hlo->CloneWithNewOperands( + shard_shape, {resharded_lhs, replicated_rhs})); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleParameter(HloInstruction* hlo) { + SetPartitionedHlo(hlo, [&]() { + auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); + auto new_param = b_.AddInstruction(HloInstruction::CreateParameter( + hlo->parameter_number(), shard_shape, "param")); + if (hlo->parameter_replicated_at_leaf_buffers()) { + new_param->set_parameter_replicated_at_leaf_buffers( + *hlo->parameter_replicated_at_leaf_buffers()); + } + return new_param; + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleReduce(HloInstruction* hlo) { + int64 input_count = 1; + auto per_input_sharding = hlo->sharding(); + if (hlo->shape().IsTuple()) { + input_count = hlo->shape().tuple_shapes_size(); + CHECK_GT(input_count, 0); + per_input_sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0}); + } + + std::vector inputs; + std::vector inits; + for (int64 operand_id = 0; operand_id < input_count; ++operand_id) { + inits.push_back(GetPartitionedHlo(hlo->operand(operand_id + input_count)) + .Reshard(HloSharding::Replicate()) + .hlo()); + inputs.push_back(GetPartitionedHlo(hlo->operand(operand_id))); + if (operand_id > 0) { + // Make sure all operands are sharded in the same way. + inputs.back() = inputs.back().Reshard(inputs[0].sharding()); + } + if (!inputs[0].sharding().IsTileMaximal()) { + inputs.back() = inputs.back().PadWithValue(inits[operand_id]); + } + } + bool reduce_sharded_dimension = false; + if (!inputs[0].sharding().IsTileMaximal()) { + reduce_sharded_dimension = absl::c_any_of(hlo->dimensions(), [&](int64 i) { + return inputs[0].sharding().tile_assignment().dim(i) > 1; + }); + + // reduce_sharded_dimension is not supported for tuple-shaped reduces. 
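For intuition on the reduce_sharded_dimension case handled below: when a reduced dimension is also the partitioned one, every partition reduces its shard locally and the partial results are combined with a cross-partition all-reduce. A minimal sketch where a plain sum over partitions emulates that all-reduce (illustrative only):

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // A length-8 input summed over its only dimension, split across two
  // partitions.
  const std::vector<std::vector<int64_t>> shards = {{1, 2, 3, 4},
                                                    {5, 6, 7, 8}};
  int64_t all_reduced = 0;
  for (const auto& shard : shards) {
    // Local reduce on each partition...
    const int64_t local =
        std::accumulate(shard.begin(), shard.end(), int64_t{0});
    // ...followed by the cross-partition all-reduce (emulated as a plain sum).
    all_reduced += local;
  }
  std::cout << all_reduced << "\n";  // 36, identical on every partition
  return 0;
}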
+ if (reduce_sharded_dimension && input_count > 1) { + return DefaultAction(hlo); + } + + // Currently we only support reducing all or none of the sharded + // dimensions. + if (reduce_sharded_dimension) { + for (int64 i = 0; i < inputs[0].base_shape().rank(); ++i) { + if (inputs[0].sharding().tile_assignment().dim(i) > 1 && + absl::c_count(hlo->dimensions(), i) == 0) { + return DefaultAction(hlo); + } + } + } + } + + std::vector new_operand_shapes(input_count * 2); + for (int64 i = 0; i < input_count; ++i) { + new_operand_shapes[i] = inputs[i].hlo()->mutable_shape(); + new_operand_shapes[i + input_count] = inits[i]->mutable_shape(); + } + // Create the shard shape of the reduce result. + TF_ASSIGN_OR_RETURN( + auto reduce_shape, + ShapeInference::InferReduceShape(new_operand_shapes, hlo->dimensions(), + hlo->to_apply()->ComputeProgramShape())); + *reduce_shape.mutable_layout() = hlo->shape().layout(); + + std::vector input_hlos(input_count); + for (int64 i = 0; i < input_count; ++i) { + input_hlos[i] = inputs[i].hlo(); + } + auto local_reduce = b_.AddInstruction(HloInstruction::CreateReduce( + reduce_shape, input_hlos, inits, hlo->dimensions(), hlo->to_apply())); + local_reduce->set_metadata(hlo->metadata()); + + SetPartitionedHlo(hlo, [&]() { + HloInstruction* reduce; + if (reduce_sharded_dimension) { + CHECK(local_reduce->shape().IsArray()); + reduce = collective_ops_creator_.create_cross_partition_all_reduce( + &b_, local_reduce, hlo->to_apply(), NewChannel()); + reduce->set_sharding(HloSharding::Replicate()); + } else { + reduce = local_reduce; + if (inputs[0].sharding().IsTileMaximal()) { + reduce->set_sharding(inputs[0].sharding()); + } else { + // Remove tile assignment dimensions that are reduced. + std::vector tile_dimensions; + for (int64 i = 0; i < input_hlos[0]->shape().rank(); ++i) { + if (absl::c_count(hlo->dimensions(), i) == 0) { + tile_dimensions.push_back( + inputs[0].sharding().tile_assignment().dim(i)); + } + } + Array new_tile = inputs[0].sharding().tile_assignment(); + new_tile.Reshape(tile_dimensions); + auto sharding = HloSharding::Tile(new_tile); + if (input_count > 1) { + std::vector tuple(input_count, sharding); + sharding = HloSharding::Tuple(hlo->shape(), tuple); + } + reduce->set_sharding(sharding); + } + } + + return PartitionedHlo(reduce, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleReverse(HloInstruction* hlo) { + auto reverse = Cast(hlo); + if (reverse->sharding().IsTileMaximal()) { + return DefaultAction(hlo); + } + if (absl::c_all_of(reverse->dimensions(), [&](int64 d) { + return reverse->sharding().tile_assignment().dim(d) == 1; + })) { + auto operand = + GetPartitionedHlo(reverse->operand(0)).Reshard(reverse->sharding()); + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction( + hlo->CloneWithNewOperands(operand.hlo()->shape(), {operand.hlo()})); + }); + return Status::OK(); + } + return DefaultAction(hlo); +} + +Status SpmdPartitioningVisitor::HandleWhile(HloInstruction* hlo) { + const HloSharding& sharding = hlo->sharding(); + + // Shardings for the body parameter, body root, and cond parameter must be + // the same, and the condition root must be replicated so that all partitions + // follow the same control flow. 
+ hlo->while_condition()->parameter_instruction(0)->set_sharding(sharding); + hlo->while_body()->parameter_instruction(0)->set_sharding(sharding); + TF_RETURN_IF_ERROR(partitioner_ + ->PartitionComputation(hlo->while_condition(), + HloSharding::Replicate(), + next_channel_id_, logger_) + .status()); + TF_RETURN_IF_ERROR(partitioner_ + ->PartitionComputation(hlo->while_body(), sharding, + next_channel_id_, logger_) + .status()); + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction(HloInstruction::CreateWhile( + MakePartitionedShape(hlo->shape(), sharding), hlo->while_condition(), + hlo->while_body(), + GetPartitionedHlo(hlo->operand(0)).Reshard(sharding).hlo())); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleConditional(HloInstruction* hlo) { + std::vector branch_args; + for (int64 i = 0; i < hlo->branch_count(); ++i) { + HloComputation* computation = hlo->branch_computation(i); + + // Shardings of the branch computation parameter and its argument must be + // the same. + computation->parameter_instruction(0)->set_sharding( + hlo->operand(i + 1)->sharding()); + branch_args.push_back(GetPartitionedHlo(hlo->operand(i + 1)).hlo()); + } + + // The root of the branch computations must follow the sharding of the + // conditional instruction. + for (int64 i = 0; i < hlo->branch_count(); ++i) { + HloComputation* computation = hlo->branch_computation(i); + TF_RETURN_IF_ERROR(partitioner_ + ->PartitionComputation(computation, hlo->sharding(), + next_channel_id_, logger_) + .status()); + } + + // We replicate the predicate of the conditional (the first operand) so that + // all partitions follow the same control flow. + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction(HloInstruction::CreateConditional( + MakePartitionedShape(hlo->shape(), hlo->sharding()), + GetPartitionedHlo(hlo->operand(0)) + .Reshard(HloSharding::Replicate()) + .hlo(), + hlo->called_computations(), branch_args)); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleOutfeed(HloInstruction* hlo) { + TF_RET_CHECK(hlo->sharding().HasUniqueDevice()); + return HandleSingleDevice(hlo); +} + +Status SpmdPartitioningVisitor::HandleRng(HloInstruction* hlo) { + if (hlo->sharding().HasUniqueDevice()) { + return HandleSingleDevice(hlo); + } + + if (hlo->sharding().IsReplicated()) { + SetPartitionedHlo(hlo, [&] { + // Run on a single device (0) and distribute the data to all other cores. + std::vector new_operands; + for (int64 i = 0; i < hlo->operand_count(); ++i) { + new_operands.push_back(GetPartitionedHlo(hlo->operand(i)) + .Reshard(HloSharding::AssignDevice(0)) + .hlo()); + } + auto clone = b_.AddInstruction( + hlo->CloneWithNewOperands(hlo->shape(), new_operands)); + clone->set_sharding(HloSharding::AssignDevice(0)); + return PartitionedHlo(clone, hlo->shape(), MakePartitioningState()) + .Reshard(HloSharding::Replicate()) + .hlo(); + }); + return Status::OK(); + } + + TF_RET_CHECK(!hlo->sharding().IsTileMaximal()); + SetPartitionedHlo(hlo, [&] { + // Replicate the operands and run partitioned Rng on all devices. 
+ std::vector new_operands; + for (int64 i = 0; i < hlo->operand_count(); ++i) { + new_operands.push_back(GetPartitionedHlo(hlo->operand(i)) + .Reshard(HloSharding::Replicate()) + .hlo()); + } + return b_.AddInstruction(HloInstruction::CreateRng( + MakePartitionedShape(hlo->shape(), hlo->sharding()), + hlo->random_distribution(), new_operands)); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleReduceWindow(HloInstruction* hlo) { + auto& operand = GetPartitionedHlo(hlo->operand(0)); + if (hlo->sharding().IsTileMaximal()) { + return DefaultAction(hlo); + } + + // Replicate init + auto replicated_init = GetPartitionedHlo(hlo->mutable_operand(1)) + .Reshard(HloSharding::Replicate()); + auto resharded_operand_and_window = operand.ReshardAsWindowedInput( + hlo->window(), hlo->sharding(), replicated_init.hlo()); + if (!resharded_operand_and_window.has_value()) { + return DefaultAction(hlo); + } + + TF_ASSIGN_OR_RETURN(Shape sharded_rw_shape, + ShapeInference::InferReduceWindowShape( + resharded_operand_and_window->sharded_input->shape(), + replicated_init.hlo()->shape(), + resharded_operand_and_window->shard_window, + hlo->to_apply()->ComputeProgramShape())); + auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); + *sharded_rw_shape.mutable_layout() = shard_shape.layout(); + SetPartitionedHlo(hlo, [&]() { + auto sharded_rw = b_.AddInstruction(HloInstruction::CreateReduceWindow( + sharded_rw_shape, resharded_operand_and_window->sharded_input, + replicated_init.hlo(), resharded_operand_and_window->shard_window, + hlo->to_apply())); + if (!resharded_operand_and_window->dynamic_slice_index_on_output + .has_value()) { + CHECK(ShapeUtil::Compatible(shard_shape, sharded_rw->shape())); + return sharded_rw; + } + return b_.AddInstruction(HloInstruction::CreateDynamicSlice( + shard_shape, sharded_rw, + *resharded_operand_and_window->dynamic_slice_index_on_output, + shard_shape.dimensions())); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleSelectAndScatter(HloInstruction* hlo) { + if (hlo->sharding().IsTileMaximal()) { + return DefaultAction(hlo); + } + auto operand = GetPartitionedHlo(hlo->operand(0)); + auto source = GetPartitionedHlo(hlo->mutable_operand(1)); + if (hlo->sharding() != operand.sharding()) { + operand = operand.Reshard(hlo->sharding()); + } + if (hlo->sharding() != source.sharding()) { + source = source.Reshard(hlo->sharding()); + } + + // For F32 and BF16 types, we can use NaN padding to workaround the issue with + // low/high padding, since comparison will return false with NaN input. 
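As an aside, a minimal standalone sketch (hypothetical values, not XLA code) of why the pad value chosen just below is safe: for a kGe/kGt "max-style" select, padding the halo region with -infinity guarantees a padded element never wins against real data, so the scatter never lands on padding.

  #include <cstdio>
  #include <limits>

  // Illustrative "greater-or-equal" select, mirroring the kGe case below.
  static float SelectGe(float a, float b) { return a >= b ? a : b; }

  int main() {
    const float kPad = -std::numeric_limits<float>::infinity();
    std::printf("%f\n", SelectGe(kPad, 3.0f));   // 3.0: padding never selected
    std::printf("%f\n", SelectGe(-7.0f, kPad));  // -7.0: real data still wins
  }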
+  if (hlo->shape().element_type() != F32 &&
+      hlo->shape().element_type() != BF16) {
+    return DefaultAction(hlo);
+  }
+
+  auto select = hlo->called_computations()[0];
+  auto select_root = select->root_instruction();
+  if (select_root->opcode() != HloOpcode::kCompare ||
+      select_root->operand(0)->opcode() != HloOpcode::kParameter ||
+      select_root->operand(1)->opcode() != HloOpcode::kParameter ||
+      select_root->operand(0)->parameter_number() +
+              select_root->operand(1)->parameter_number() !=
+          1) {
+    return DefaultAction(hlo);
+  }
+
+  float float_pad_value;
+  if (select_root->comparison_direction() == ComparisonDirection::kGe ||
+      select_root->comparison_direction() == ComparisonDirection::kGt) {
+    if (select_root->operand(0)->parameter_number() == 0) {
+      float_pad_value = -std::numeric_limits<float>::infinity();
+    } else {
+      float_pad_value = std::numeric_limits<float>::infinity();
+    }
+  } else if (select_root->comparison_direction() == ComparisonDirection::kLe ||
+             select_root->comparison_direction() == ComparisonDirection::kLt) {
+    if (select_root->operand(0)->parameter_number() == 0) {
+      float_pad_value = std::numeric_limits<float>::infinity();
+    } else {
+      float_pad_value = -std::numeric_limits<float>::infinity();
+    }
+  } else {
+    return DefaultAction(hlo);
+  }
+
+  auto pad_value = b_.AddInstruction(HloInstruction::CreateConstant(
+      hlo->shape().element_type() == BF16
+          ? LiteralUtil::CreateR0<bfloat16>(
+                static_cast<bfloat16>(float_pad_value))
+          : LiteralUtil::CreateR0<float>(float_pad_value)));
+
+  // Replicate init
+  auto replicated_init = GetPartitionedHlo(hlo->mutable_operand(2))
+                             .Reshard(HloSharding::Replicate());
+
+  auto partition_ordinals =
+      MakeTiledPartitionOrdinals(hlo->sharding(), partition_id_, &b_);
+
+  // The first window for each dimension that overlaps with the shard area.
+  std::vector<MultiplyAddDivideOffsetCalculation> first_window(
+      hlo->shape().rank());
+  // The first window for each dimension that goes beyond the shard area.
+  std::vector<MultiplyAddDivideOffsetCalculation> limit_window(
+      hlo->shape().rank());
+  std::vector<OffsetCalculation> data_left_halo_sizes(hlo->shape().rank());
+  std::vector<OffsetCalculation> data_right_halo_sizes(hlo->shape().rank());
+  std::vector<OffsetCalculation> source_left_halo_sizes(hlo->shape().rank());
+  std::vector<OffsetCalculation> source_right_halo_sizes(hlo->shape().rank());
+  auto unpadded_data_shard_shape =
+      MakePartitionedShape(hlo->shape(), hlo->sharding());
+  auto unpadded_source_shard_shape =
+      MakePartitionedShape(hlo->operand(1)->shape(), hlo->sharding());
+  auto source_shard_hlo = source.hlo();
+  auto data_shard_hlo = operand.hlo();
+  for (int64 i = 0; i < hlo->shape().rank(); ++i) {
+    int64 shard_count = hlo->sharding().tile_assignment().dim(i);
+    if (shard_count == 1) {
+      continue;
+    }
+    // If stride > window_size, there will be gaps between windows. These gaps
+    // will also exist in the output, so we keep them during halo exchange.
+    //
+    // TODO(yuanzx): This could introduce overhead if partitions start at
+    // different offsets in a gap.
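A rough standalone illustration (hypothetical sizes) of the gap case handled next: with stride 3 and window size 2, one element per period falls between windows; widening the per-shard window size to the stride keeps those gap elements in the windows so they survive the halo exchange and land in the output unchanged.

  #include <cstdio>

  int main() {
    const int stride = 3, window_size = 2, num_windows = 4;
    for (int k = 0; k < num_windows; ++k) {
      // Original window k covers [k*stride, k*stride + window_size), leaving a
      // gap at k*stride + 2; with size widened to stride, windows tile the axis.
      std::printf("window %d: [%d, %d) -> widened to [%d, %d)\n", k, k * stride,
                  k * stride + window_size, k * stride, (k + 1) * stride);
    }
  }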
+ auto wd = hlo->window().dimensions(i); + if (wd.stride() > wd.size()) { + wd.set_size(wd.stride()); + } + // shard_size * i < stride * k - pad_low + window_size => + // k > (shard_size * i + pad_low - window_size) / stride => + // first_k == (shard_size * i + pad_low - window_size + stride) / stride + first_window[i] = MultiplyAddDivideOffsetCalculation( + unpadded_data_shard_shape.dimensions(i), + wd.padding_low() - wd.size() + wd.stride(), wd.stride()); + // shard_size * (i + 1) <= stride * k - pad_low => + // k >= (shard_size * i + shard_size + pad_low) / stride => + // limit_k == (shard_size * i + shard_size + pad_low + stride - 1) / + // stride + limit_window[i] = MultiplyAddDivideOffsetCalculation( + unpadded_data_shard_shape.dimensions(i), + unpadded_data_shard_shape.dimensions(i) + wd.padding_low() + + wd.stride() - 1, + wd.stride()); + source_left_halo_sizes[i] = + MultiplyAddDivideOffsetCalculation( + unpadded_source_shard_shape.dimensions(i), 0, 1) - + first_window[i]; + source_right_halo_sizes[i] = + limit_window[i] - MultiplyAddDivideOffsetCalculation( + unpadded_source_shard_shape.dimensions(i), + unpadded_source_shard_shape.dimensions(i), 1); + data_left_halo_sizes[i] = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + unpadded_data_shard_shape.dimensions(i), wd.padding_low(), 1)) - + OffsetCalculation( + HloOpcode::kMultiply, first_window[i], + MultiplyAddDivideOffsetCalculation(0, wd.stride(), 1)); + data_right_halo_sizes[i] = + OffsetCalculation( + HloOpcode::kMultiply, limit_window[i], + MultiplyAddDivideOffsetCalculation(0, wd.stride(), 1)) - + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + unpadded_data_shard_shape.dimensions(i), + unpadded_data_shard_shape.dimensions(i) + wd.stride() + + wd.padding_low() - wd.size(), + 1)); + + int64 max_windows = + (limit_window[i] - first_window[i]).MaxInRange(0, shard_count); + auto first_window_hlo = + first_window[i].Calculate(partition_ordinals[i], &b_); + // Padding on the source is filled with the init value so they do not change + // the data on overlapping windows. 
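A worked instance of the first_k/limit_k derivations above, with hypothetical sizes; it only assumes that MultiplyAddDivideOffsetCalculation(m, a, d) evaluates to (m * ordinal + a) / d, which is how the assignments above encode those formulas.

  #include <cstdio>

  // (m * i + a) / d, mirroring how the derivations above are encoded.
  static int MultiplyAddDivide(int m, int a, int d, int i) {
    return (m * i + a) / d;
  }

  int main() {
    const int shard_size = 8, window_size = 3, stride = 2, pad_low = 1;
    for (int i = 0; i < 2; ++i) {
      int first_k = MultiplyAddDivide(shard_size,
                                      pad_low - window_size + stride, stride, i);
      int limit_k = MultiplyAddDivide(shard_size,
                                      shard_size + pad_low + stride - 1, stride, i);
      // Shard 0 owns windows [0, 5) and shard 1 owns [4, 9); the overlap is the
      // halo that the source/data exchanges below provide.
      std::printf("shard %d: windows [%d, %d)\n", i, first_k, limit_k);
    }
  }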
+ auto resharded_source = ExchangeHaloAndGetValidData( + source_shard_hlo, source.base_shape(), source_left_halo_sizes[i], + source_right_halo_sizes[i], 0, + limit_window[i].Calculate(shard_count - 1), max_windows, i, + hlo->sharding(), first_window_hlo, replicated_init.hlo(), + partition_ordinals[i], collective_ops_creator_, next_channel_id_, &b_); + if (!resharded_source) { + return DefaultAction(hlo); + } + source_shard_hlo = *resharded_source; + + auto offset_start_in_data = + MultiplyAddDivideOffsetCalculation(wd.stride(), 0, 1) + .Calculate(first_window_hlo, &b_); + int64 padded_data_size = + (limit_window[i].Calculate(shard_count - 1) - 1) * wd.stride() + + wd.size(); + int64 data_shard_size = (max_windows - 1) * wd.stride() + wd.size(); + auto resharded_data = ExchangeHaloAndGetValidData( + data_shard_hlo, operand.base_shape(), data_left_halo_sizes[i], + data_right_halo_sizes[i], wd.padding_low(), padded_data_size, + data_shard_size, i, hlo->sharding(), offset_start_in_data, pad_value, + partition_ordinals[i], collective_ops_creator_, next_channel_id_, &b_); + if (!resharded_data) { + return DefaultAction(hlo); + } + data_shard_hlo = *resharded_data; + } + + Window window_on_shard = hlo->window(); + for (int64 i = 0; i < window_on_shard.dimensions_size(); ++i) { + int64 shard_count = hlo->sharding().tile_assignment().dim(i); + if (shard_count == 1) { + continue; + } + auto reshard_wd = window_on_shard.mutable_dimensions(i); + // The shards are already explicitly padded. + reshard_wd->set_padding_low(0); + reshard_wd->set_padding_high(0); + } + + auto sharded_select_and_scatter = + b_.AddInstruction(HloInstruction::CreateSelectAndScatter( + data_shard_hlo->shape(), data_shard_hlo, select, window_on_shard, + source_shard_hlo, replicated_init.hlo(), + hlo->called_computations()[1])); + SetPartitionedHlo(hlo, [&]() { + auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); + if (ShapeUtil::Compatible(sharded_select_and_scatter->shape(), + shard_shape)) { + return sharded_select_and_scatter; + } + auto zero = b_.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(S32))); + std::vector slice_offsets(shard_shape.rank(), zero); + for (int64 i = 0; i < window_on_shard.dimensions_size(); ++i) { + if (hlo->sharding().tile_assignment().dim(i) == 1) { + continue; + } + int64 pad_low = hlo->window().dimensions(i).padding_low(); + auto left_halo_size = + data_left_halo_sizes[i].Calculate(partition_ordinals[i], &b_); + if (data_left_halo_sizes[i].Calculate(0) == pad_low) { + slice_offsets[i] = left_halo_size; + } else { + auto is_shard0 = b_.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), zero, partition_ordinals[i], + ComparisonDirection::kEq)); + auto pad_low_hlo = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(pad_low))); + slice_offsets[i] = b_.AddInstruction(HloInstruction::CreateTernary( + zero->shape(), HloOpcode::kSelect, is_shard0, pad_low_hlo, + left_halo_size)); + } + } + return b_.AddInstruction(HloInstruction::CreateDynamicSlice( + shard_shape, sharded_select_and_scatter, slice_offsets, + shard_shape.dimensions())); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleTuple(HloInstruction* hlo) { + std::vector new_operands; + for (int64 i = 0; i < hlo->operand_count(); ++i) { + new_operands.push_back( + GetPartitionedHlo(hlo->operand(i)) + .Reshard(hlo->sharding().GetSubSharding(hlo->shape(), {i})) + .hlo()); + } + SetPartitionedHlo(hlo, [&]() { + return 
b_.AddInstruction(HloInstruction::CreateTuple(new_operands)); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleConvolutionTiledLhsAndRhs( + HloInstruction* hlo) { + TF_RET_CHECK(hlo->opcode() == HloOpcode::kConvolution); + + auto lhs = GetPartitionedHlo(hlo->operand(0)); + auto rhs = GetPartitionedHlo(hlo->operand(1)); + TF_RET_CHECK(!lhs.sharding().IsTileMaximal() && + !rhs.sharding().IsTileMaximal()); + + const auto& dnums = hlo->convolution_dimension_numbers(); + + // Check if the operand shardings are aligned. Also we currently don't + // support partitioning non-spatial dimensions. + std::vector rhs_to_lhs_indices(hlo->shape().rank()); + rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] = + dnums.input_batch_dimension(); + rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] = + dnums.input_feature_dimension(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] = + dnums.input_spatial_dimensions(i); + } + std::vector lhs_to_rhs_indices(hlo->shape().rank()); + for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) { + lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i; + } + auto aligned_rhs_sharding = + hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices); + auto aligned_lhs_sharding = + hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices); + + auto unsupported_sharding = [&](const HloSharding& lhs_sharding, + const HloSharding& rhs_sharding) { + return lhs_sharding.tile_assignment().dim(dnums.input_batch_dimension()) != + 1 || + rhs_sharding.tile_assignment().dim( + dnums.kernel_output_feature_dimension()) != 1; + }; + + auto zero = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(hlo->shape().element_type()))); + if (ShapeUtil::ByteSizeOf(lhs.base_shape()) < + ShapeUtil::ByteSizeOf(rhs.base_shape())) { + if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) { + return DefaultAction(hlo); + } + lhs = lhs.Reshard(aligned_lhs_sharding).PadWithValue(zero); + rhs = rhs.PadWithValue(zero); + } else { + if (unsupported_sharding(lhs.sharding(), aligned_rhs_sharding)) { + return DefaultAction(hlo); + } + lhs = lhs.PadWithValue(zero); + rhs = rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero); + } + + // Reshard LHS by exchanging halo such that each shard computes the partial + // sum of the full shape result, and add AllReduce. + // + // The size of halo on each dimension can be calculated from the projection + // onto the LHS that each RHS shard i needs to read. RHS and LHS below refers + // to the shard size of RHS and LHS, WC is the number of windows, and D is the + // window dilation. 
+ // + // * offset(i): RHS * D * i - low_padding + // * limit(i): {(RHS - 1) * D + 1} * (i + 1) + (WC - 1) * stride - low_padding + // + // Since shard i has LHS of range [i * LHS, (i + 1) * LHS) + // * left-halo: i * LHS - offset(i) + // = (LHS - RHS) * i + low_padding + // * right-halo: limit(i) - (i + 1) * LHS + // = [{(RHS - 1) * D + 1} - LHS] * (i + 1) + (WC - 1) * stride - low_padding + + Window window = hlo->window(); + std::vector shard_counts(dnums.input_spatial_dimensions_size()); + std::vector lhs_shard_sizes(dnums.input_spatial_dimensions_size()); + std::vector rhs_shard_sizes(dnums.input_spatial_dimensions_size()); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + int64 lhs_dimension = dnums.input_spatial_dimensions(i); + int64 rhs_dimension = dnums.kernel_spatial_dimensions(i); + int64 shard_count = lhs.sharding().tile_assignment().dim(lhs_dimension); + auto wd = window.dimensions(i); + if (wd.base_dilation() != 1 || wd.window_reversal()) { + return DefaultAction(hlo); + } + + int64 lhs_shard_size = + CeilOfRatio(lhs.base_shape().dimensions(lhs_dimension), shard_count); + int64 rhs_shard_size = + CeilOfRatio(rhs.base_shape().dimensions(rhs_dimension), shard_count); + shard_counts[i] = shard_count; + lhs_shard_sizes[i] = lhs_shard_size; + rhs_shard_sizes[i] = rhs_shard_size; + } + + std::vector left_halo_size_functions(hlo->shape().rank()); + std::vector right_halo_size_functions(hlo->shape().rank()); + Window new_window = window; + + auto partition_ordinals = + MakeTiledPartitionOrdinals(lhs.sharding(), partition_id_, &b_); + HloInstruction* lhs_with_halo = lhs.hlo(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + int64 lhs_dimension = dnums.input_spatial_dimensions(i); + int64 lhs_shard_size = lhs_shard_sizes[i]; + int64 rhs_shard_size = rhs_shard_sizes[i]; + + if (shard_counts[i] == 1) { + continue; + } + + // Calculate the left and right halo sizes as described in the comments + // above. + auto wd = window.dimensions(i); + int64 padding_low = wd.padding_low(); + int64 padding_high = wd.padding_high(); + int64 base = lhs.base_shape().dimensions(lhs_dimension); + int64 window_count = 1 + (padding_low + padding_high + base - + (1 + (wd.size() - 1) * wd.window_dilation())) / + wd.stride(); + int64 rhs_shard_size_dilated = + (rhs_shard_size - 1) * wd.window_dilation() + 1; + + left_halo_size_functions[lhs_dimension] = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + lhs_shard_size - rhs_shard_size * wd.window_dilation(), padding_low, + 1)); + right_halo_size_functions[lhs_dimension] = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + rhs_shard_size_dilated - lhs_shard_size, + rhs_shard_size_dilated - lhs_shard_size + + wd.stride() * (window_count - 1) - padding_low, + 1)); + + // Exchange halo and concatenate. + int64 dim = dnums.input_spatial_dimensions(i); + int64 explicit_left_padding_on_full_shape = padding_low; + int64 shard_size_with_halo = + wd.stride() * (window_count - 1) + rhs_shard_size_dilated; + + new_window.mutable_dimensions(i)->set_padding_low(0); + new_window.mutable_dimensions(i)->set_padding_high(0); + new_window.mutable_dimensions(i)->set_size(rhs_shard_size); + + // offset_on_padded_shape and padded_full_shape_size are needed only if + // we want to mask out-of-range values in ExchangeHaloAndGetValidData(). + // Since the default value for both the collective-permute is zero and + // also we call PadWithValue() on both operands at the beginning, we + // don't need to mask here. 
+ // + // TODO(hyoulkee): Consider removing one of the two PadWithValue() calls + // if it's always safe. + auto offset_on_padded_shape = + OffsetCalculation(MultiplyAddDivideOffsetCalculation()); + int64 padded_full_shape_size = 0; + auto concat = ExchangeHaloAndGetValidData( + lhs_with_halo, lhs.base_shape(), left_halo_size_functions[dim], + right_halo_size_functions[dim], explicit_left_padding_on_full_shape, + padded_full_shape_size, shard_size_with_halo, dim, lhs.sharding(), + offset_on_padded_shape.Calculate(partition_ordinals[dim], &b_), zero, + partition_ordinals[dim], collective_ops_creator_, next_channel_id_, &b_, + /*mask_invalid_region=*/false); + if (!concat) { + return DefaultAction(hlo); + } + lhs_with_halo = *concat; + } + + SetPartitionedHlo(hlo, [&]() { + auto conv = b_.AddInstruction(HloInstruction::CreateConvolve( + hlo->shape(), lhs_with_halo, rhs.hlo(), hlo->feature_group_count(), + hlo->batch_group_count(), new_window, + hlo->convolution_dimension_numbers(), hlo->precision_config())); + auto ar = collective_ops_creator_.create_cross_partition_all_reduce( + &b_, conv, MakeBinaryAdd(hlo->shape().element_type(), module_), + NewChannel()); + ar->set_sharding(HloSharding::Replicate()); + return PartitionedHlo(ar, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); +} + +Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) { + auto lhs = GetPartitionedHlo(hlo->operand(0)); + auto rhs = GetPartitionedHlo(hlo->operand(1)); + const HloSharding& sharding = hlo->sharding(); + const auto& dnums = hlo->convolution_dimension_numbers(); + std::vector rhs_to_lhs_indices(hlo->shape().rank()); + rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] = + dnums.input_batch_dimension(); + rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] = + dnums.input_feature_dimension(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] = + dnums.input_spatial_dimensions(i); + } + std::vector lhs_to_rhs_indices(hlo->shape().rank()); + for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) { + lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i; + } + auto aligned_rhs_sharding = + hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices); + auto aligned_lhs_sharding = + hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices); + + // Handling cases where both operands' shardings are aligned. We check that + // the LHS batch dimension is not partitioned because it is mapped to the + // output feature dimension in aligned_rhs_sharding, which are not the same + // dimension. + if (!lhs.sharding().IsTileMaximal() && !rhs.sharding().IsTileMaximal()) { + if (options_.conv_halo_exchange_always_on_lhs) { + return HandleConvolutionTiledLhsAndRhs(hlo); + } else { + // Reshard RHS so that each shard computes the partial sum of the full + // shape result, and add AllReduce. See HandleConvolutionTiledLhsAndRhs() + // that reshards LHS. + // + // The size of halo on each dimension can be calculated from the + // projection onto the RHS that shard i needs to read. RHS and LHS below + // refers to the shard size of RHS and LHS, WC is the number of windows, + // and D is the window dilation. 
+      //
+      // * offset(i): LHS * i + low_padding - (WC - 1) * stride
+      // * limit(i): LHS * (i + 1) + low_padding
+      //
+      // Since shard i has RHS of range [i * RHS * D, (i + 1) * RHS * D)
+      // * left-halo: i * RHS - offset(i)
+      //   = i * (RHS * D - LHS) + (WC - 1) * stride - low_padding
+      // * right-halo: limit(i) - (i + 1) * RHS
+      //   = (i + 1) * (LHS - RHS * D) + low_padding
+
+      auto unsupported_sharding = [&](const HloSharding& lhs_sharding,
+                                      const HloSharding& rhs_sharding) {
+        // We currently don't support partitioning input batch or output feature
+        // dimensions.
+        return lhs_sharding.tile_assignment().dim(
+                   dnums.input_batch_dimension()) != 1 ||
+               rhs_sharding.tile_assignment().dim(
+                   dnums.kernel_output_feature_dimension()) != 1;
+      };
+      auto zero = b_.AddInstruction(HloInstruction::CreateConstant(
+          LiteralUtil::Zero(hlo->shape().element_type())));
+      if (ShapeUtil::ByteSizeOf(lhs.base_shape()) <
+          ShapeUtil::ByteSizeOf(rhs.base_shape())) {
+        if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) {
+          return DefaultAction(hlo);
+        }
+        lhs = lhs.Reshard(aligned_lhs_sharding).PadWithValue(zero);
+        rhs = rhs.PadWithValue(zero);
+      } else {
+        if (unsupported_sharding(lhs.sharding(), aligned_rhs_sharding)) {
+          return DefaultAction(hlo);
+        }
+        lhs = lhs.PadWithValue(zero);
+        rhs = rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero);
+      }
+
+      Window window = hlo->window();
+      std::vector<int64> shard_counts(dnums.input_spatial_dimensions_size());
+      std::vector<int64> lhs_shard_sizes(dnums.input_spatial_dimensions_size());
+      std::vector<int64> rhs_shard_sizes(dnums.input_spatial_dimensions_size());
+      for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
+        int64 lhs_dimension = dnums.input_spatial_dimensions(i);
+        int64 rhs_dimension = dnums.kernel_spatial_dimensions(i);
+        int64 shard_count = rhs.sharding().tile_assignment().dim(rhs_dimension);
+        auto wd = window.dimensions(i);
+        if (wd.base_dilation() != 1 || wd.window_reversal()) {
+          return DefaultAction(hlo);
+        }
+
+        int64 lhs_shard_size = CeilOfRatio(
+            lhs.base_shape().dimensions(lhs_dimension), shard_count);
+        int64 rhs_shard_size = CeilOfRatio(
+            rhs.base_shape().dimensions(rhs_dimension), shard_count);
+        shard_counts[i] = shard_count;
+        lhs_shard_sizes[i] = lhs_shard_size;
+        rhs_shard_sizes[i] = rhs_shard_size;
+      }
+
+      std::vector<OffsetCalculation> left_halo_size_functions(
+          hlo->shape().rank());
+      std::vector<OffsetCalculation> right_halo_size_functions(
+          hlo->shape().rank());
+      Window new_window = window;
+
+      // Data structures needed for Pad and DynamicSlice on LHS if needed.
+      bool need_dynamic_slice_lhs = false;
+      auto partition_ordinals =
+          MakeTiledPartitionOrdinals(lhs.sharding(), partition_id_, &b_);
+      std::vector<int64> zero_padding(hlo->shape().rank());
+      PaddingConfig pad_config =
+          window_util::MakeSymmetricPadding(zero_padding);
+      auto zero_s32 = b_.AddInstruction(
+          HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
+      std::vector<HloInstruction*> dynamic_slice_start_indices(
+          hlo->shape().rank(), zero_s32);
+      Shape dynamic_slice_shape = lhs.hlo()->shape();
+      Shape pad_shape = lhs.hlo()->shape();
+
+      for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
+        int64 lhs_dimension = dnums.input_spatial_dimensions(i);
+        int64 rhs_dimension = dnums.kernel_spatial_dimensions(i);
+        int64 lhs_shard_size = lhs_shard_sizes[i];
+        int64 rhs_shard_size = rhs_shard_sizes[i];
+
+        if (shard_counts[i] == 1) {
+          continue;
+        }
+
+        // Calculate the left and right halo sizes as described in the comments
+        // above. It calculates the halo sizes with dilation, so we apply
+        // CeilOfRatio({left,right}_halo_size, window_dilation).
+        auto wd = window.dimensions(i);
+        int64 padding_low = wd.padding_low();
+        int64 padding_high = wd.padding_high();
+        int64 base = lhs.base_shape().dimensions(lhs_dimension);
+        int64 window_count =
+            1 + (padding_low + padding_high + base -
+                 (1 + (wd.size() - 1) * wd.window_dilation())) /
+                    wd.stride();
+        left_halo_size_functions[rhs_dimension] =
+            OffsetCalculation(MultiplyAddDivideOffsetCalculation(
+                rhs_shard_size * wd.window_dilation() - lhs_shard_size,
+                (window_count - 1) * wd.stride() - padding_low +
+                    wd.window_dilation() - 1,
+                wd.window_dilation()));
+        right_halo_size_functions[rhs_dimension] =
+            OffsetCalculation(MultiplyAddDivideOffsetCalculation(
+                lhs_shard_size - rhs_shard_size * wd.window_dilation(),
+                lhs_shard_size - rhs_shard_size * wd.window_dilation() +
+                    padding_low + wd.window_dilation() - 1,
+                wd.window_dilation()));
+
+        // New RHS window size includes the maximum of both left and right
+        // halos.
+        int64 halo_size = left_halo_size_functions[rhs_dimension].MaxInRange(
+                              1, shard_counts[i]) +
+                          right_halo_size_functions[rhs_dimension].MaxInRange(
+                              0, shard_counts[i] - 1);
+        int64 new_window_size =
+            rhs.hlo()->shape().dimensions(rhs_dimension) + halo_size;
+
+        // The amount of new low padding could be dynamic (e.g., window_dilation
+        // != 1), which requires pad (to the maximum) and dynamic slice on LHS.
+        //
+        // If we consider the first window, the offset of the dilated RHS that
+        // aligns with the first valid LHS element for shard i is 'padding_low +
+        // LHS * i'. When the left halo is added to RHS, the offset of the first
+        // RHS element is (RHS * i - left_halo) * window_dilation. The
+        // difference between the two values is the amount of padding_low we
+        // need on LHS.
+        auto new_padding_low_function =
+            OffsetCalculation(
+                HloOpcode::kMultiply, left_halo_size_functions[rhs_dimension],
+                OffsetCalculation(MultiplyAddDivideOffsetCalculation(
+                    0, wd.window_dilation(), 1))) -
+            OffsetCalculation(MultiplyAddDivideOffsetCalculation(
+                rhs_shard_size * wd.window_dilation() - lhs_shard_size,
+                -padding_low, 1));
+
+        int64 new_padding_low_max =
+            new_padding_low_function.MaxInRange(0, shard_counts[i]);
+        int64 new_padding_low = new_padding_low_max;
+        int64 new_padding_high = window_count * wd.stride() +
+                                 (new_window_size - 1) * wd.window_dilation() -
+                                 new_padding_low - lhs_shard_size;
+
+        // We do pad/dynamic-slice only when the padding is dynamic.
+        if (!new_padding_low_function.IsConstant()) {
+          need_dynamic_slice_lhs = true;
+          new_padding_low = 0;
+          pad_config.mutable_dimensions(lhs_dimension)
+              ->set_edge_padding_low(new_padding_low_max);
+          pad_config.mutable_dimensions(lhs_dimension)
+              ->set_edge_padding_high(new_padding_low_max);
+          pad_shape.set_dimensions(lhs_dimension,
+                                   lhs_shard_size + 2 * new_padding_low_max);
+          dynamic_slice_start_indices[lhs_dimension] =
+              (OffsetCalculation(MultiplyAddDivideOffsetCalculation(
+                   0, new_padding_low_max, 1)) -
+               new_padding_low_function)
+                  .Calculate(partition_ordinals[lhs_dimension], &b_);
+          dynamic_slice_shape.set_dimensions(
+              lhs_dimension, lhs_shard_size + new_padding_low_max);
+        }
+
+        // Since the convolution RHS operand size increased with halos, adjust
+        // the window config accordingly.
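A small worked example (hypothetical numbers) of the CeilOfRatio adjustment mentioned in the comment above: a halo measured in dilated coordinates must be converted back to a count of kernel elements.

  #include <cstdio>

  static int CeilOfRatio(int a, int b) { return (a + b - 1) / b; }

  int main() {
    // With window_dilation 2, consecutive kernel elements are 2 apart in the
    // dilated space, so a 5-element halo in dilated coordinates only needs
    // ceil(5 / 2) = 3 extra kernel elements.
    const int halo_in_dilated_space = 5, window_dilation = 2;
    std::printf("%d\n", CeilOfRatio(halo_in_dilated_space, window_dilation));
  }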
+ new_window.mutable_dimensions(i)->set_padding_low(new_padding_low); + new_window.mutable_dimensions(i)->set_padding_high(new_padding_high); + new_window.mutable_dimensions(i)->set_size( + rhs.hlo()->shape().dimensions(rhs_dimension) + halo_size); + } + + HloInstruction* conv_lhs = lhs.hlo(); + if (need_dynamic_slice_lhs) { + auto pad = b_.AddInstruction( + HloInstruction::CreatePad(pad_shape, lhs.hlo(), zero, pad_config)); + conv_lhs = b_.AddInstruction(HloInstruction::CreateDynamicSlice( + dynamic_slice_shape, pad, dynamic_slice_start_indices, + dynamic_slice_shape.dimensions())); + } + + // Exchange halo and concatenate. + HloInstruction* rhs_with_halo = rhs.hlo(); + for (int i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) { + int64 dim = dnums.kernel_spatial_dimensions(i); + int64 explicit_left_padding_on_full_shape = + left_halo_size_functions[dim].Calculate(0); + int64 shard_size_with_halo = new_window.dimensions(i).size(); + + // offset_on_padded_shape and padded_full_shape_size are needed only if + // we want to mask out-of-range values in ExchangeHaloAndGetValidData(). + // Since the default value for both the collective-permute is zero and + // also we call PadWithValue() on both operands at the beginning, we + // don't need to mask here. + // + // TODO(hyoulkee): Consider removing one of the two PadWithValue() calls + // if it's always safe. + auto offset_on_padded_shape = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + rhs_shard_sizes[i], explicit_left_padding_on_full_shape, 1)) - + left_halo_size_functions[dim]; + int64 padded_full_shape_size = + offset_on_padded_shape.Calculate(shard_counts[i] - 1) + + new_window.dimensions(i).size(); + auto concat = ExchangeHaloAndGetValidData( + rhs_with_halo, rhs.base_shape(), left_halo_size_functions[dim], + right_halo_size_functions[dim], explicit_left_padding_on_full_shape, + padded_full_shape_size, shard_size_with_halo, dim, rhs.sharding(), + offset_on_padded_shape.Calculate(partition_ordinals[dim], &b_), + zero, partition_ordinals[dim], collective_ops_creator_, + next_channel_id_, &b_, /*mask_invalid_region=*/false); + if (!concat) { + return DefaultAction(hlo); + } + rhs_with_halo = *concat; + } + + SetPartitionedHlo(hlo, [&]() { + auto conv = b_.AddInstruction(HloInstruction::CreateConvolve( + hlo->shape(), conv_lhs, rhs_with_halo, hlo->feature_group_count(), + hlo->batch_group_count(), new_window, dnums, + hlo->precision_config())); + auto ar = collective_ops_creator_.create_cross_partition_all_reduce( + &b_, conv, MakeBinaryAdd(hlo->shape().element_type(), module_), + NewChannel()); + ar->set_sharding(HloSharding::Replicate()); + return PartitionedHlo(ar, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); + } + } + + if (!sharding.IsTileMaximal()) { + // We don't currently support sharding on output feature dimension. + if (sharding.tile_assignment().dim(dnums.output_feature_dimension()) > 1) { + return DefaultAction(hlo); + } + + // Check if the operand and the output sharding are aligned. 
+ std::vector input_to_output_indices(hlo->shape().rank()); + input_to_output_indices[dnums.input_batch_dimension()] = + dnums.output_batch_dimension(); + input_to_output_indices[dnums.input_feature_dimension()] = + dnums.output_feature_dimension(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + input_to_output_indices[dnums.input_spatial_dimensions(i)] = + dnums.output_spatial_dimensions(i); + } + auto target_operand_sharding = + hlo_sharding_util::TransposeSharding(sharding, input_to_output_indices); + lhs = lhs.Reshard(target_operand_sharding); + + // Replicate the RHS. + rhs = rhs.Reshard(HloSharding::Replicate()); + + // Convolution window config does not include batch and feature dimensions, + // whereas ReshardAsWindowedInput() expects the same number of window + // dimensions as the rank of the operand. So add two more trivial + // dimensions. + std::vector ones(hlo->shape().rank(), 1); + auto operand_window = window_util::MakeWindow(ones); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + *operand_window.mutable_dimensions(dnums.input_spatial_dimensions(i)) = + hlo->window().dimensions(i); + } + + auto zero = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(hlo->shape().element_type()))); + auto resharded_operand_and_window = lhs.ReshardAsWindowedInput( + operand_window, target_operand_sharding, zero); + if (!resharded_operand_and_window.has_value()) { + return DefaultAction(hlo); + } + Window new_window; + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + *new_window.add_dimensions() = + resharded_operand_and_window->shard_window.dimensions( + dnums.input_spatial_dimensions(i)); + } + TF_ASSIGN_OR_RETURN( + Shape sharded_conv_shape, + ShapeInference::InferConvolveShape( + resharded_operand_and_window->sharded_input->shape(), + rhs.hlo()->shape(), hlo->feature_group_count(), + hlo->batch_group_count(), new_window, dnums)); + auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); + *sharded_conv_shape.mutable_layout() = shard_shape.layout(); + SetPartitionedHlo(hlo, [&]() { + auto sharded_conv = b_.AddInstruction(HloInstruction::CreateConvolve( + sharded_conv_shape, resharded_operand_and_window->sharded_input, + rhs.hlo(), hlo->feature_group_count(), hlo->batch_group_count(), + new_window, dnums, hlo->precision_config())); + if (!resharded_operand_and_window->dynamic_slice_index_on_output + .has_value()) { + CHECK(ShapeUtil::Compatible(shard_shape, sharded_conv->shape())); + return sharded_conv; + } + return b_.AddInstruction(HloInstruction::CreateDynamicSlice( + shard_shape, sharded_conv, + *resharded_operand_and_window->dynamic_slice_index_on_output, + shard_shape.dimensions())); + }); + return Status::OK(); + } + return DefaultAction(hlo); +} + +Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) { + DotGeneralDimsMapping mapping; + const auto& dnums = hlo->dot_dimension_numbers(); + int64 next_output_dim = 0; + for (int64 i = 0; i < dnums.lhs_batch_dimensions_size(); ++i) { + mapping.batch_dims.emplace_back(); + mapping.batch_dims.back().lhs = dnums.lhs_batch_dimensions(i); + mapping.batch_dims.back().rhs = dnums.rhs_batch_dimensions(i); + mapping.batch_dims.back().output = next_output_dim++; + } + for (int64 i = 0; i < dnums.lhs_contracting_dimensions_size(); ++i) { + mapping.contracting_dims.emplace_back(); + mapping.contracting_dims.back().lhs = dnums.lhs_contracting_dimensions(i); + mapping.contracting_dims.back().rhs = dnums.rhs_contracting_dimensions(i); + 
mapping.contracting_dims.back().output = -1; + } + for (int64 i = 0; i < hlo->operand(0)->shape().rank(); ++i) { + if (absl::c_linear_search(dnums.lhs_batch_dimensions(), i) || + absl::c_linear_search(dnums.lhs_contracting_dimensions(), i)) { + continue; + } + mapping.lhs_non_contracting_dims.emplace_back(); + mapping.lhs_non_contracting_dims.back().lhs = i; + mapping.lhs_non_contracting_dims.back().rhs = -1; + mapping.lhs_non_contracting_dims.back().output = next_output_dim++; + } + for (int64 i = 0; i < hlo->operand(1)->shape().rank(); ++i) { + if (absl::c_linear_search(dnums.rhs_batch_dimensions(), i) || + absl::c_linear_search(dnums.rhs_contracting_dimensions(), i)) { + continue; + } + mapping.rhs_non_contracting_dims.emplace_back(); + mapping.rhs_non_contracting_dims.back().lhs = -1; + mapping.rhs_non_contracting_dims.back().rhs = i; + mapping.rhs_non_contracting_dims.back().output = next_output_dim++; + } + auto create_sharded_dot = [&](HloInstruction* l, HloInstruction* r, + SpmdBuilder* b) -> StatusOr { + TF_ASSIGN_OR_RETURN( + auto sharded_dot_shape, + ShapeInference::InferDotOpShape(l->shape(), r->shape(), + hlo->dot_dimension_numbers())); + return b->AddInstruction(HloInstruction::CreateDot( + sharded_dot_shape, l, r, hlo->dot_dimension_numbers(), + hlo->precision_config())); + }; + return HandleDotHelper(hlo, mapping, create_sharded_dot); +} + +Status SpmdPartitioningVisitor::HandleDotHelper( + HloInstruction* hlo, const DotGeneralDimsMapping& dims_mapping, + const std::function( + HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot) { + const HloSharding& lhs_sharding = hlo->operand(0)->sharding(); + const HloSharding& rhs_sharding = hlo->operand(1)->sharding(); + + // Similar to hlo_sharding_util::TransposeSharding(), but allows + // removing/adding non-partitioned dimensions. 
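For reference, a sketch of the mapping HandleDot builds for a plain batched matmul with hypothetical shapes lhs [b, m, k], rhs [b, k, n], output [b, m, n]; -1 marks a dimension that does not exist on that operand, which is exactly what the transpose helper below has to skip or add.

  #include <cstdio>

  int main() {
    // batch dims:              {lhs = 0,  rhs = 0,  output = 0}
    // contracting dims (k):    {lhs = 2,  rhs = 1,  output = -1}
    // lhs non-contracting (m): {lhs = 1,  rhs = -1, output = 1}
    // rhs non-contracting (n): {lhs = -1, rhs = 2,  output = 2}
    const int lhs_to_output_indices[] = {0, 1, -1};  // k has no output dim
    const int output_to_lhs_indices[] = {0, 1, -1};  // n has no lhs dim
    for (int i = 0; i < 3; ++i) {
      std::printf("lhs %d -> output %d, output %d -> lhs %d\n", i,
                  lhs_to_output_indices[i], i, output_to_lhs_indices[i]);
    }
  }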
+ auto transpose_sharding = + [&](const HloSharding& source, absl::Span src_to_tgt, + absl::Span tgt_to_src) -> absl::optional { + if (source.IsTileMaximal()) { + return source; + } + std::vector tgt_dims_skipping_new(tgt_to_src.size(), -1); + int64 skipped_tgt_dims = 0; + for (int64 i = 0; i < tgt_to_src.size(); ++i) { + if (tgt_to_src[i] < 0) { + skipped_tgt_dims++; + } else { + tgt_dims_skipping_new[i] = i - skipped_tgt_dims; + } + } + int64 skipped_src_dims = absl::c_count(src_to_tgt, -1); + std::vector perm(src_to_tgt.size()); + for (int64 i = 0; i < src_to_tgt.size(); ++i) { + if (src_to_tgt[i] < 0) { + if (source.tile_assignment().dim(i) > 1) { + return absl::nullopt; + } + perm[src_to_tgt.size() - skipped_src_dims] = i; + skipped_src_dims--; + } else { + perm[tgt_dims_skipping_new[src_to_tgt[i]]] = i; + } + } + auto tgt_sharding = hlo_sharding_util::TransposeSharding(source, perm); + if (skipped_tgt_dims == 0) { + return tgt_sharding; + } + auto reshape_tiles = tgt_sharding.tile_assignment(); + std::vector tgt_tiles(tgt_to_src.size(), 1); + for (int64 i = 0; i < tgt_tiles.size(); ++i) { + if (tgt_to_src[i] >= 0) { + tgt_tiles[i] = reshape_tiles.dim(tgt_dims_skipping_new[i]); + } + } + reshape_tiles.Reshape(tgt_tiles); + return HloSharding::Tile(reshape_tiles); + }; + + std::vector lhs_to_rhs_indices(hlo->operand(0)->shape().rank(), -1); + std::vector lhs_to_output_indices(hlo->operand(0)->shape().rank(), -1); + std::vector rhs_to_lhs_indices(hlo->operand(1)->shape().rank(), -1); + std::vector rhs_to_output_indices(hlo->operand(1)->shape().rank(), -1); + std::vector output_to_lhs_indices(hlo->shape().rank(), -1); + std::vector output_to_rhs_indices(hlo->shape().rank(), -1); + auto populate_indices_mapping = + [&](const DotGeneralDimsMapping::DimsMapping& mapping) { + if (mapping.lhs >= 0) { + lhs_to_rhs_indices[mapping.lhs] = mapping.rhs; + lhs_to_output_indices[mapping.lhs] = mapping.output; + } + if (mapping.rhs >= 0) { + rhs_to_lhs_indices[mapping.rhs] = mapping.lhs; + rhs_to_output_indices[mapping.rhs] = mapping.output; + } + if (mapping.output >= 0) { + output_to_lhs_indices[mapping.output] = mapping.lhs; + output_to_rhs_indices[mapping.output] = mapping.rhs; + } + }; + for (const auto& mapping : dims_mapping.batch_dims) { + populate_indices_mapping(mapping); + } + for (const auto& mapping : dims_mapping.contracting_dims) { + populate_indices_mapping(mapping); + } + for (const auto& mapping : dims_mapping.lhs_non_contracting_dims) { + populate_indices_mapping(mapping); + } + for (const auto& mapping : dims_mapping.rhs_non_contracting_dims) { + populate_indices_mapping(mapping); + } + auto lhs_sharding_transposed_to_match_rhs = + transpose_sharding(lhs_sharding, lhs_to_rhs_indices, rhs_to_lhs_indices); + auto rhs_sharding_transposed_to_match_lhs = + transpose_sharding(rhs_sharding, rhs_to_lhs_indices, lhs_to_rhs_indices); + auto lhs_sharding_transposed_to_match_output = transpose_sharding( + lhs_sharding, lhs_to_output_indices, output_to_lhs_indices); + auto rhs_sharding_transposed_to_match_output = transpose_sharding( + rhs_sharding, rhs_to_output_indices, output_to_rhs_indices); + auto output_sharding_transposed_to_match_lhs = transpose_sharding( + hlo->sharding(), output_to_lhs_indices, lhs_to_output_indices); + auto output_sharding_transposed_to_match_rhs = transpose_sharding( + hlo->sharding(), output_to_rhs_indices, rhs_to_output_indices); + + // lhs_rhs_or_output: 0 lhs, 1 rhs, 2 output. 
+ auto get_partitions_for_dims = + [&](const HloSharding& sharding, + absl::Span dims, + int lhs_rhs_or_output) { + int64 partitions = 1; + if (sharding.IsTileMaximal()) { + return partitions; + } + for (const auto& dim : dims) { + if (lhs_rhs_or_output == 0) { + partitions *= sharding.tile_assignment().dim(dim.lhs); + } else if (lhs_rhs_or_output == 1) { + partitions *= sharding.tile_assignment().dim(dim.rhs); + } else { + CHECK_EQ(lhs_rhs_or_output, 2); + partitions *= sharding.tile_assignment().dim(dim.output); + } + } + return partitions; + }; + const int64 lhs_batch_partitions = + get_partitions_for_dims(lhs_sharding, dims_mapping.batch_dims, 0); + const int64 rhs_batch_partitions = + get_partitions_for_dims(rhs_sharding, dims_mapping.batch_dims, 1); + const int64 output_batch_partitions = + get_partitions_for_dims(hlo->sharding(), dims_mapping.batch_dims, 2); + const int64 lhs_contracting_partitions = + get_partitions_for_dims(lhs_sharding, dims_mapping.contracting_dims, 0); + const int64 rhs_contracting_partitions = + get_partitions_for_dims(rhs_sharding, dims_mapping.contracting_dims, 1); + const int64 lhs_non_contracting_partitions = get_partitions_for_dims( + lhs_sharding, dims_mapping.lhs_non_contracting_dims, 0); + const int64 rhs_non_contracting_partitions = get_partitions_for_dims( + rhs_sharding, dims_mapping.rhs_non_contracting_dims, 1); + const int64 output_lhs_non_contracting_partitions = get_partitions_for_dims( + hlo->sharding(), dims_mapping.lhs_non_contracting_dims, 2); + const int64 output_rhs_non_contracting_partitions = get_partitions_for_dims( + hlo->sharding(), dims_mapping.rhs_non_contracting_dims, 2); + + auto& lhs = GetPartitionedHlo(hlo->operand(0)); + auto& rhs = GetPartitionedHlo(hlo->operand(1)); + // LHS and RHS are partitioned the same way and only partitioned in batch + // dimensions. + if (lhs_batch_partitions == rhs_batch_partitions && + rhs_batch_partitions == num_partitions_ && + lhs_sharding_transposed_to_match_rhs == rhs_sharding) { + TF_ASSIGN_OR_RETURN(auto dot, + create_sharded_dot(lhs.hlo(), rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { + dot->set_sharding(*lhs_sharding_transposed_to_match_output); + return PartitionedHlo(dot, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); + } + + // Try emit batch-partitioned einsum with one operand resharded. Returns + // whether the attempt succeeds. If may_reshard_with_allreduce is false, + // reshard must be done using all-to-all; otherwise this attempt fails. + auto try_emit_output_batch_partitioned_einsum_with_reshard = + [&](bool may_reshard_with_allreduce) -> StatusOr { + // LHS and output are batch partitioned in the same way. + if (lhs_batch_partitions == num_partitions_ && + output_batch_partitions == num_partitions_ && + lhs_sharding_transposed_to_match_output == hlo->sharding()) { + if (!may_reshard_with_allreduce && + !CanReshardWithAllToAll(rhs.sharding(), + *lhs_sharding_transposed_to_match_rhs)) { + return false; + } + auto resharded_rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs); + TF_ASSIGN_OR_RETURN( + auto dot, create_sharded_dot(lhs.hlo(), resharded_rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { return dot; }); + return true; + } + // RHS and output are batch partitioned in the same way. 
+ if (rhs_batch_partitions == num_partitions_ && + output_batch_partitions == num_partitions_ && + rhs_sharding_transposed_to_match_output == hlo->sharding()) { + if (!may_reshard_with_allreduce && + !CanReshardWithAllToAll(lhs.sharding(), + *rhs_sharding_transposed_to_match_lhs)) { + return false; + } + auto resharded_lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs); + TF_ASSIGN_OR_RETURN( + auto dot, create_sharded_dot(resharded_lhs.hlo(), rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { return dot; }); + return true; + } + return false; + }; + + { + // Try batch-parallel by resharding one operand, and not using all-reduce. + TF_ASSIGN_OR_RETURN( + bool emitted, + try_emit_output_batch_partitioned_einsum_with_reshard(false)); + if (emitted) { + return Status::OK(); + } + } + + // Try to emit windowed DotGeneral when one operand is partitioned in the same + // way as the output along non-contracting dimensions, but the other operand + // is tiled in other dimensions. + auto emit_windowed_dot_general = [&](int64 matching_operand, + int64 windowing_operand, + bool windowed_at_contracting_dims, + bool windowed_at_batch_dims) { + CHECK_EQ(matching_operand + windowing_operand, 1); + CHECK(!windowed_at_batch_dims || !windowed_at_contracting_dims); + auto unpadded_result_buffer_shape = + MakePartitionedShape(hlo->shape(), hlo->sharding()); + auto padded_result_buffer_shape = unpadded_result_buffer_shape; + // For windowing at batch/non-contracting dims, we produce the result one + // partition at a time, so we need to pad the shape in case of uneven + // partitioning in order to make dynamic-update-slice in-bound. + if (!windowed_at_contracting_dims) { + padded_result_buffer_shape = GetPaddedShapeForUnevenPartitioning( + padded_result_buffer_shape, + windowing_operand == 0 ? *lhs_sharding_transposed_to_match_output + : *rhs_sharding_transposed_to_match_output); + } + // Mask the padding area of the windowed operand with zero if there is + // uneven partitioning. + if (windowed_at_contracting_dims) { + auto& to_mask = windowing_operand == 0 ? lhs : rhs; + to_mask = + to_mask.PadWithValue(b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(hlo->shape().element_type())))); + } + auto result_buffer = CreateZero(padded_result_buffer_shape, &b_); + auto iteration = b_.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + + // Create a while loop that computes one window per iteration. During each + // iteration, each partition sends its input window to its neighbor using + // collective-permute for the next iteration. 
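A standalone sketch (hypothetical partition count) of the rotation schedule the loop body below sets up: at iteration i, partition p works on the slice that originally lived on partition (p + i) mod n, and the collective-permute then shifts each window from partition s to s - 1 so the next iteration sees the next slice.

  #include <cstdio>

  int main() {
    const int num_partitions = 4;
    for (int i = 0; i < num_partitions; ++i) {
      for (int p = 0; p < num_partitions; ++p) {
        // Matches data_partition_id = (i + partition_id) % num_partitions.
        int data_partition_id = (i + p) % num_partitions;
        std::printf("iter %d: partition %d computes slice %d\n", i, p,
                    data_partition_id);
      }
      // The shift matches the source->destination pairs {s, (s - 1 + n) % n}
      // used for the collective-permute feeding the next iteration.
    }
  }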
+ SpmdBuilder body_b("windowed_dot_general_body", visiting_hlo_); + auto param = body_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, + ShapeUtil::MakeTupleShape({lhs.hlo()->shape(), rhs.hlo()->shape(), + result_buffer->shape(), iteration->shape()}), + "param")); + auto l = body_b.AddInstruction( + HloInstruction::CreateGetTupleElement(lhs.hlo()->shape(), param, 0)); + auto r = body_b.AddInstruction( + HloInstruction::CreateGetTupleElement(rhs.hlo()->shape(), param, 1)); + auto o = body_b.AddInstruction(HloInstruction::CreateGetTupleElement( + result_buffer->shape(), param, 2)); + auto i = body_b.AddInstruction( + HloInstruction::CreateGetTupleElement(iteration->shape(), param, 3)); + + auto partition_id = collective_ops_creator_.create_partition_id(&body_b); + auto data_partition_id = body_b.AddInstruction(HloInstruction::CreateBinary( + i->shape(), HloOpcode::kAdd, i, partition_id)); + auto partition_count = body_b.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(num_partitions_))); + data_partition_id = body_b.AddInstruction(HloInstruction::CreateBinary( + i->shape(), HloOpcode::kRemainder, data_partition_id, partition_count)); + auto dot_lhs = l; + auto dot_rhs = r; + if (windowed_at_contracting_dims || windowed_at_batch_dims) { + // Slice the matching operand according to the partitioned contracting + // dimensions on the windowed operand. We do this by treating the matching + // operand as replicated, and resharding it to match the windowed operand. + auto slice_operand = matching_operand == 0 ? l : r; + slice_operand->set_sharding(HloSharding::Replicate()); + auto state = MakePartitioningState(); + state.b = &body_b; + state.partition_id = data_partition_id; + auto slice = PartitionedHlo(slice_operand, slice_operand->shape(), state) + .Reshard(windowing_operand == 0 + ? *lhs_sharding_transposed_to_match_rhs + : *rhs_sharding_transposed_to_match_lhs) + .hlo(); + slice_operand->clear_sharding(); + if (matching_operand == 0) { + dot_lhs = slice; + } else { + dot_rhs = slice; + } + } + TF_ASSIGN_OR_RETURN(auto dot, + create_sharded_dot(dot_lhs, dot_rhs, &body_b)); + if (windowed_at_contracting_dims) { + // Accumulate the partial output to the result buffer. + o = body_b.AddInstruction( + HloInstruction::CreateBinary(o->shape(), HloOpcode::kAdd, o, dot)); + } else { + // The windowing operand is partitioned along batch/non-contracting + // dimensions, so we need a dynamic-update-slice to save the partial + // output in the result buffer. + auto offsets = MakePartitionOffsets( + o->shape(), + windowing_operand == 0 ? *lhs_sharding_transposed_to_match_output + : *rhs_sharding_transposed_to_match_output, + data_partition_id, &body_b); + o = body_b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + o->shape(), o, dot, offsets)); + } + + // ++i + i = body_b.AddInstruction(HloInstruction::CreateBinary( + i->shape(), HloOpcode::kAdd, i, + body_b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))))); + auto has_more = body_b.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), i, + body_b.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(num_partitions_))), + ComparisonDirection::kLt)); + // Collective-permute for the next window. We don't need it for the last + // iteration, so we use a conditional around the collective-permute. 
+ HloInstruction* conditional; + { + SpmdBuilder cp_b("window_collective_permute", visiting_hlo_); + { + auto p = cp_b.AddInstruction(HloInstruction::CreateParameter( + 0, windowing_operand == 0 ? l->shape() : r->shape(), "window")); + std::vector> sd_pairs(num_partitions_); + for (int64 source = 0; source < num_partitions_; ++source) { + // 0 -> n-1, 1 -> 0, 2 -> 1, ... + sd_pairs[source] = {source, + (source - 1 + num_partitions_) % num_partitions_}; + } + collective_ops_creator_.create_cross_partition_collective_permute( + &cp_b, p, sd_pairs, (*next_channel_id_)++); + } + SpmdBuilder ncp_b("last_iteration_noop", visiting_hlo_); + { + ncp_b.AddInstruction(HloInstruction::CreateParameter( + 0, windowing_operand == 0 ? l->shape() : r->shape(), "window")); + } + conditional = body_b.AddInstruction(HloInstruction::CreateConditional( + windowing_operand == 0 ? l->shape() : r->shape(), has_more, + windowing_operand == 0 ? l : r, + module_->AddEmbeddedComputation(cp_b.Build()), + windowing_operand == 0 ? l : r, + module_->AddEmbeddedComputation(ncp_b.Build()))); + } + if (windowing_operand == 0) { + l = conditional; + } else { + r = conditional; + } + body_b.AddInstruction(HloInstruction::CreateTuple({l, r, o, i})); + + SpmdBuilder cond_b("windowed_dot_general_cond", visiting_hlo_); + auto cond_param = cond_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, + ShapeUtil::MakeTupleShape({lhs.hlo()->shape(), rhs.hlo()->shape(), + result_buffer->shape(), iteration->shape()}), + "param")); + auto cond_i = cond_b.AddInstruction(HloInstruction::CreateGetTupleElement( + iteration->shape(), cond_param, 3)); + cond_b.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), cond_i, + cond_b.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(num_partitions_))), + ComparisonDirection::kLt)); + auto while_loop = b_.AddInstruction(HloInstruction::CreateWhile( + cond_param->shape(), module_->AddEmbeddedComputation(cond_b.Build()), + module_->AddEmbeddedComputation(body_b.Build()), + b_.AddInstruction(HloInstruction::CreateTuple( + {lhs.hlo(), rhs.hlo(), result_buffer, iteration})))); + windowed_dot_general_loops_.push_back({while_loop, windowing_operand, + windowed_at_contracting_dims, + windowed_at_batch_dims}); + SetPartitionedHlo(hlo, [&] { + auto result = b_.AddInstruction(HloInstruction::CreateGetTupleElement( + result_buffer->shape(), while_loop, 2)); + if (!ShapeUtil::Compatible(padded_result_buffer_shape, + unpadded_result_buffer_shape)) { + result = b_.AddInstruction(HloInstruction::CreateSlice( + unpadded_result_buffer_shape, result, + std::vector(padded_result_buffer_shape.rank(), 0), + unpadded_result_buffer_shape.dimensions(), + std::vector(padded_result_buffer_shape.rank(), 1))); + } + return result; + }); + return Status::OK(); + }; + if (output_lhs_non_contracting_partitions == num_partitions_ && + output_sharding_transposed_to_match_lhs == lhs_sharding && + ShapeUtil::ByteSizeOf(hlo->operand(1)->shape()) >= + options_.threshold_for_windowed_einsum_mib * 1024 * 1024) { + if (rhs_contracting_partitions == num_partitions_) { + return emit_windowed_dot_general(0, 1, true, false); + } + if (rhs_non_contracting_partitions == num_partitions_) { + return emit_windowed_dot_general(0, 1, false, false); + } + if (rhs_batch_partitions == num_partitions_) { + return emit_windowed_dot_general(0, 1, false, true); + } + } + if (output_rhs_non_contracting_partitions == num_partitions_ && + output_sharding_transposed_to_match_rhs == 
rhs_sharding && + ShapeUtil::ByteSizeOf(hlo->operand(0)->shape()) >= + options_.threshold_for_windowed_einsum_mib * 1024 * 1024) { + if (lhs_contracting_partitions == num_partitions_) { + return emit_windowed_dot_general(1, 0, true, false); + } + if (lhs_non_contracting_partitions == num_partitions_) { + return emit_windowed_dot_general(1, 0, false, false); + } + if (lhs_batch_partitions == num_partitions_) { + return emit_windowed_dot_general(1, 0, false, true); + } + } + + { + // Try batch-parallel by resharding one operand, and allowing all-reduce. + TF_ASSIGN_OR_RETURN( + bool emitted, + try_emit_output_batch_partitioned_einsum_with_reshard(true)); + if (emitted) { + return Status::OK(); + } + } + + // LHS and RHS have the same partitioned contracting dimensions. + if (lhs_contracting_partitions == rhs_contracting_partitions && + lhs_contracting_partitions == num_partitions_) { + auto zero = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(hlo->shape().element_type()))); + // Pad both sides with zero, since NaN at one side cannot be masked by zero + // on the other side. + if (ShapeUtil::ByteSizeOf(lhs.base_shape()) < + ShapeUtil::ByteSizeOf(rhs.base_shape())) { + lhs = + lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithValue(zero); + rhs = rhs.PadWithValue(zero); + } else { + lhs = lhs.PadWithValue(zero); + rhs = + rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero); + } + TF_ASSIGN_OR_RETURN(auto dot, + create_sharded_dot(lhs.hlo(), rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { + auto ar = collective_ops_creator_.create_cross_partition_all_reduce( + &b_, dot, MakeBinaryAdd(hlo->shape().element_type(), module_), + NewChannel()); + ar->set_sharding(HloSharding::Replicate()); + return PartitionedHlo(ar, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); + } + + // LHS and output have the same partitioned non-contracting dimensions. + if (lhs_non_contracting_partitions == num_partitions_ && + output_lhs_non_contracting_partitions == num_partitions_ && + lhs_sharding == hlo->sharding()) { + auto rhs_replicated = rhs.Reshard(HloSharding::Replicate()).hlo(); + TF_ASSIGN_OR_RETURN(auto dot, + create_sharded_dot(lhs.hlo(), rhs_replicated, &b_)); + SetPartitionedHlo(hlo, [&] { return dot; }); + return Status::OK(); + } + + // RHS and output have the same partitioned non-contracting dimensions. + if (rhs_non_contracting_partitions == num_partitions_ && + output_rhs_non_contracting_partitions == num_partitions_ && + rhs_sharding_transposed_to_match_output == hlo->sharding()) { + auto lhs_replicated = lhs.Reshard(HloSharding::Replicate()).hlo(); + TF_ASSIGN_OR_RETURN(auto dot, + create_sharded_dot(lhs_replicated, rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { return dot; }); + return Status::OK(); + } + + // Output is batch partitioned. + if (output_batch_partitions == num_partitions_) { + auto resharded_lhs = lhs.Reshard(*output_sharding_transposed_to_match_lhs); + auto resharded_rhs = rhs.Reshard(*output_sharding_transposed_to_match_rhs); + TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(resharded_lhs.hlo(), + resharded_rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { return dot; }); + return Status::OK(); + } + // Output is partitioned along LHS non-contracting dimensions. 
+ if (output_lhs_non_contracting_partitions == num_partitions_) { + auto resharded_lhs = lhs.Reshard(*output_sharding_transposed_to_match_lhs); + auto replicated_rhs = rhs.Reshard(HloSharding::Replicate()); + TF_ASSIGN_OR_RETURN( + auto dot, + create_sharded_dot(resharded_lhs.hlo(), replicated_rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { return dot; }); + return Status::OK(); + } + // Output is partitioned along RHS non-contracting dimensions. + if (output_rhs_non_contracting_partitions == num_partitions_) { + auto replicated_lhs = lhs.Reshard(HloSharding::Replicate()); + auto resharded_rhs = rhs.Reshard(*output_sharding_transposed_to_match_rhs); + TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(replicated_lhs.hlo(), + resharded_rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { return dot; }); + return Status::OK(); + } + + // Returns true if it is beneficial to reshard the operand at `operand_idx` + // across the contracting dimension. + const auto should_partition_contracting_dim = [&](int64 operand_idx) { + if (!hlo->sharding().IsReplicated()) { + return false; + } + + if (operand_idx == 0) { + // If LHS and output are replicated, we compare the cost of all-gather + // on RHS vs all-reduce on the output. + return (rhs_contracting_partitions == num_partitions_) && + lhs.sharding().IsReplicated() && + ShapeUtil::ElementsIn(hlo->operand(1)->shape()) > + ShapeUtil::ElementsIn(hlo->shape()); + } else { + return (lhs_contracting_partitions == num_partitions_) && + rhs.sharding().IsReplicated() && + ShapeUtil::ElementsIn(hlo->operand(0)->shape()) > + ShapeUtil::ElementsIn(hlo->shape()); + } + }; + + // When the output is replicated and one of the operands is partitioned along + // contracting dimension, align the other operand to be partitioned along + // the contracting dimensions. + if (hlo->sharding().IsReplicated() && (should_partition_contracting_dim(0) || + should_partition_contracting_dim(1))) { + auto zero = b_.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(hlo->shape().element_type()))); + if (should_partition_contracting_dim(0)) { + lhs = + lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithValue(zero); + rhs = rhs.PadWithValue(zero); + } else { + lhs = lhs.PadWithValue(zero); + rhs = + rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero); + } + TF_ASSIGN_OR_RETURN(auto dot, + create_sharded_dot(lhs.hlo(), rhs.hlo(), &b_)); + SetPartitionedHlo(hlo, [&] { + auto ar = collective_ops_creator_.create_cross_partition_all_reduce( + &b_, dot, MakeBinaryAdd(hlo->shape().element_type(), module_), + NewChannel()); + ar->set_sharding(HloSharding::Replicate()); + return PartitionedHlo(ar, hlo->shape(), MakePartitioningState()).hlo(); + }); + return Status::OK(); + } + + return DefaultAction(hlo); +} + +namespace { + +// Finds a cluster of nodes that produce the inputs for `hlo` which only depend +// on small operands, which means the cluster should start with broadcasts, +// constants and iotas. All other internal nodes must be non-side-effecting +// elemntwise ops. Returns the set of nodes, and the small operands. E.g., for +// the following graph, +// +// a -> broadcast -> multiply +// iota ---> add--/ +// constant/ +// +// FindInputNodesIfOnlyDependOnSmallOperands(multiply) will return +// <{broadcast, iota, constant, add, multiply}, [a]>. 
+std::pair, std::vector> +FindInputNodesIfOnlyDependOnSmallOperands(HloInstruction* hlo) { + std::unordered_set nodes_found; + std::vector new_operands; + std::unordered_set new_operands_set; + std::vector worklist; + worklist.push_back(hlo); + while (!worklist.empty()) { + auto inst = worklist.back(); + worklist.pop_back(); + if (nodes_found.count(inst) > 0) { + continue; + } + if (inst->opcode() == HloOpcode::kBroadcast || + inst->opcode() == HloOpcode::kConstant || + inst->opcode() == HloOpcode::kIota) { + nodes_found.insert(inst); + for (auto o : inst->operands()) { + auto res = new_operands_set.emplace(o); + if (res.second) { + new_operands.push_back(o); + } + } + } else if (inst->IsElementwise() && !inst->HasSideEffectNoRecurse() && + inst->opcode() != HloOpcode::kAllReduce && + absl::c_all_of(inst->operands(), + [inst](const HloInstruction* o) { + return ShapeUtil::CompatibleIgnoringElementType( + o->shape(), inst->shape()); + })) { + nodes_found.insert(inst); + for (auto o : inst->operands()) { + worklist.push_back(o); + } + } else { + nodes_found.clear(); + new_operands.clear(); + break; + } + } + return {std::move(nodes_found), std::move(new_operands)}; +} + +// Moves a cluster of memory-reducing nodes into the windowed dot-general loop +// on contracting dimensions. Such a loop has a dynamic slice on the +// non-windowed operand. If we move the input nodes into the loop, the +// dynamic-slice could be merged with them by later optimization passes, which +// reduces memory. +// +// small_operands small_operands +// | | +// input_nodes loop { | +// | => input_nodes +// loop { | | +// dynamic-slice dynamic-slice +// ... ... +// } } +// +// Later optimization passes (TpuPadSliceMover) will merge the dynamic slice +// with the input nodes. +Status SinkInputNodesIntoWindowedDotGeneralLoopOnContractingDimensions( + HloInstruction* loop, int64 non_windowed_operand_index) { + auto input_tuple = loop->mutable_operand(0); + auto old_operand = input_tuple->mutable_operand(non_windowed_operand_index); + auto input_nodes = FindInputNodesIfOnlyDependOnSmallOperands(old_operand); + auto to_sink = std::move(input_nodes.first); + auto new_operands = std::move(input_nodes.second); + if (to_sink.empty()) { + return Status::OK(); + } + auto computation = loop->parent(); + // Replace the old operand with a tuple of the found small operands. + auto new_input_subtuple = + computation->AddInstruction(HloInstruction::CreateTuple(new_operands)); + TF_RETURN_IF_ERROR(input_tuple->ReplaceOperandWithDifferentShape( + non_windowed_operand_index, new_input_subtuple)); + + auto body = loop->while_body(); + auto body_param = body->parameter_instruction(0); + auto old_body_param_users = body_param->users(); + // Update all tuple shapes. + for (auto tuple : std::vector{ + input_tuple, loop, loop->while_condition()->parameter_instruction(0), + body_param, body->root_instruction()}) { + *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(), + {non_windowed_operand_index}) = + new_input_subtuple->shape(); + } + // Now update the loop body. + auto new_operand_tuple_inside = + body->AddInstruction(HloInstruction::CreateGetTupleElement( + new_input_subtuple->shape(), body_param, non_windowed_operand_index)); + TF_RETURN_IF_ERROR(body->root_instruction()->ReplaceOperandWithDifferentShape( + non_windowed_operand_index, new_operand_tuple_inside)); + + // Create nodes inside the loop body. 
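+  // The cloning below is worklist driven: an outside node is cloned into the
+  // loop body only after all of its operands have inside versions recorded in
+  // outside_to_inside, starting from the elements of the new input subtuple
+  // and the nullary ops (constants, iotas) in the cluster.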
+ std::vector worklist; + std::unordered_map outside_to_inside; + auto add_users_if_available = [&](HloInstruction* inst) { + for (auto u : inst->users()) { + if (outside_to_inside.count(u) == 0 && to_sink.count(u) > 0 && + absl::c_all_of(u->operands(), [&](const HloInstruction* o) { + return outside_to_inside.count(o) > 0; + })) { + worklist.push_back(u); + } + } + }; + for (int64 i = 0; i < new_operands.size(); ++i) { + outside_to_inside[new_operands[i]] = + body->AddInstruction(HloInstruction::CreateGetTupleElement( + new_operands[i]->shape(), new_operand_tuple_inside, i)); + add_users_if_available(new_operands[i]); + } + // HLOs to sink without operands. + std::vector nullaries_to_sink; + for (auto inst : to_sink) { + if (inst->operand_count() == 0) { + nullaries_to_sink.push_back(inst); + } + } + // Sort nullaries_to_sink to make it deterministic. + absl::c_sort(nullaries_to_sink, + [](const HloInstruction* a, const HloInstruction* b) { + return a->unique_id() < b->unique_id(); + }); + for (auto inst : nullaries_to_sink) { + worklist.push_back(inst); + } + while (!worklist.empty()) { + auto inst = worklist.back(); + worklist.pop_back(); + std::vector inst_new_operands(inst->operand_count()); + for (int64 i = 0; i < inst->operand_count(); ++i) { + inst_new_operands[i] = outside_to_inside[inst->operand(i)]; + } + outside_to_inside[inst] = body->AddInstruction( + inst->CloneWithNewOperands(inst->shape(), inst_new_operands)); + add_users_if_available(inst); + } + TF_RET_CHECK(outside_to_inside.count(old_operand) > 0); + for (auto ou : old_body_param_users) { + if (ou->opcode() == HloOpcode::kGetTupleElement && + ou->tuple_index() == non_windowed_operand_index) { + TF_RETURN_IF_ERROR( + ou->ReplaceAllUsesWith(outside_to_inside[old_operand])); + TF_RETURN_IF_ERROR(body->RemoveInstruction(ou)); + } + } + return Status::OK(); +} + +// Moves a cluster of memory-reducing nodes (with reduce nodes at the end) into +// the windowed dot-general loop on non-contracting dimensions. Such a loop has +// a dynamic-update-slice at the output. If we move the user nodes into the loop +// and before the dynamic-update-slice, the user nodes can operate on smaller +// shapes, which reduces memory. +// +// small_operands small_operands +// | | => | | +// | | loop { loop { | | +// | | conv | broadcast conv +// | | | | | / +// | | dynamic-update-slice | dynamic-slice / +// | | | | | / +// | | } | | multiply----- +// |broadcast / | / +// | | / reduce +// |multiply-- | +// \ | dynamic-update-slice +// reduce } +// +// Later optimization passes (TpuPadSliceMover) will merge the dynamic slice +// with the input nodes (broadcast). +Status MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions( + HloInstruction* loop) { + CHECK_EQ(loop->user_count(), 1); + // There should be a single direct user of the while loop, which is the + // gte for element 2, i.e., the dot output. + auto user_gte = loop->users().front(); + CHECK_EQ(user_gte->opcode(), HloOpcode::kGetTupleElement); + CHECK_EQ(user_gte->tuple_index(), 2); + auto computation = loop->parent(); + + // Find the reduce outputs and the input nodes they depend on, if input nodes + // only have small operands. 
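+  // Only reduces whose to_apply computation is a single two-parameter
+  // elementwise root qualify, e.g. (illustrative HLO):
+  //   sum { a = f32[] parameter(0)
+  //         b = f32[] parameter(1)
+  //         ROOT add = f32[] add(a, b) }
+  // since such a reduction can be re-applied safely to accumulate partial
+  // results across loop iterations.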
+  std::unordered_set<HloInstruction*> to_move;
+  std::vector<HloInstruction*> new_operands;
+  std::unordered_set<HloInstruction*> new_operands_set;
+  std::vector<HloInstruction*> reduce_outputs;
+  std::vector<HloInstruction*> worklist;
+  Shape padded_shape = user_gte->shape();
+  Shape unpadded_shape = user_gte->shape();
+  auto original_output = user_gte;
+
+  if (user_gte->user_count() == 1 &&
+      user_gte->users().back()->opcode() == HloOpcode::kSlice) {
+    original_output = user_gte->users().back();
+    unpadded_shape = original_output->shape();
+  }
+  for (auto u : original_output->users()) {
+    worklist.push_back(u);
+  }
+  to_move.insert(original_output);
+  while (!worklist.empty()) {
+    auto inst = worklist.back();
+    worklist.pop_back();
+    if (to_move.count(inst) > 0) {
+      continue;
+    }
+    // We only support reduces with a simple reduction function, since we may
+    // need to accumulate across iterations manually.
+    if (inst->opcode() == HloOpcode::kReduce &&
+        inst->to_apply()->instruction_count() == 3 &&
+        inst->to_apply()->num_parameters() == 2 &&
+        inst->to_apply()->root_instruction()->IsElementwise()) {
+      to_move.insert(inst);
+      auto other_operand = inst->mutable_operand(1);
+      auto res = new_operands_set.emplace(other_operand);
+      if (res.second) {
+        new_operands.push_back(other_operand);
+      }
+      reduce_outputs.push_back(inst);
+    } else if (inst != computation->root_instruction() &&
+               inst->user_count() > 0 && inst->IsElementwise() &&
+               !inst->HasSideEffectNoRecurse() &&
+               inst->opcode() != HloOpcode::kAllReduce &&
+               absl::c_all_of(inst->operands(),
+                              [inst](const HloInstruction* o) {
+                                return ShapeUtil::CompatibleIgnoringElementType(
+                                    o->shape(), inst->shape());
+                              })) {
+      // For an elementwise op, we need to make sure that it depends only on
+      // nodes already in to_move and nodes with small operands.
+      bool can_include = true;
+      for (auto operand : inst->operands()) {
+        if (to_move.count(operand) > 0) {
+          continue;
+        }
+        auto find_result = FindInputNodesIfOnlyDependOnSmallOperands(operand);
+        if (find_result.first.empty()) {
+          can_include = false;
+          break;
+        }
+        for (auto n : find_result.first) {
+          to_move.insert(n);
+        }
+        for (auto new_operand : find_result.second) {
+          auto res = new_operands_set.insert(new_operand);
+          if (res.second) {
+            new_operands.push_back(new_operand);
+          }
+        }
+      }
+      if (!can_include) {
+        to_move.clear();
+        break;
+      }
+      to_move.insert(inst);
+      for (auto u : inst->users()) {
+        worklist.push_back(u);
+      }
+    } else {
+      to_move.clear();
+      break;
+    }
+  }
+  // If nothing is found, to_move may contain only original_output, or it may
+  // have been cleared by the code above.
+  if (to_move.size() <= 1) {
+    return Status::OK();
+  }
+
+  // We will replace the original loop output with reduce-shape outputs. Create
+  // the initial buffers before the loop.
+  for (auto out : reduce_outputs) {
+    auto padded_out_shape = out->shape();
+    int64 operand_dim = 0;
+    int64 output_dim = 0;
+    while (output_dim < padded_out_shape.rank()) {
+      if (absl::c_linear_search(out->dimensions(), operand_dim)) {
+        // Dimension collapsed.
+        ++operand_dim;
+        continue;
+      }
+      // Kept dimensions have the same size as the padded shape.
+ padded_out_shape.set_dimensions(output_dim, + padded_shape.dimensions(operand_dim)); + ++operand_dim; + ++output_dim; + } + auto broadcast = + computation->AddInstruction(HloInstruction::CreateBroadcast( + padded_out_shape, + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(out->shape().element_type()))), + {})); + new_operands.push_back(broadcast); + } + + auto input_tuple = loop->mutable_operand(0); + // Create the new input subtuple that contains the small operands and the + // reduce-shape result buffers. + auto new_input_subtuple = + computation->AddInstruction(HloInstruction::CreateTuple(new_operands)); + TF_RETURN_IF_ERROR( + input_tuple->ReplaceOperandWithDifferentShape(2, new_input_subtuple)); + auto body = loop->while_body(); + auto body_param = body->parameter_instruction(0); + auto body_root = body->root_instruction(); + CHECK_EQ(body_root->opcode(), HloOpcode::kTuple); + // Update tuple shapes. + for (auto tuple : std::vector{ + input_tuple, loop, loop->while_condition()->parameter_instruction(0), + body_param, body_root}) { + *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(), {2}) = + new_input_subtuple->shape(); + } + auto new_loop_input = + body->AddInstruction(HloInstruction::CreateGetTupleElement( + new_input_subtuple->shape(), body_param, 2)); + + // Now create the moved nodes inside the loop body. + std::unordered_map outside_to_inside; + worklist.clear(); + auto add_users_if_available = [&](HloInstruction* inst) { + for (auto u : inst->users()) { + if (outside_to_inside.count(u) == 0 && to_move.count(u) > 0 && + absl::c_all_of(u->operands(), [&](const HloInstruction* o) { + return outside_to_inside.count(o) > 0; + })) { + worklist.push_back(u); + } + } + }; + for (int64 i = 0; i < new_operands.size(); ++i) { + outside_to_inside[new_operands[i]] = + body->AddInstruction(HloInstruction::CreateGetTupleElement( + new_operands[i]->shape(), new_loop_input, i)); + add_users_if_available(new_operands[i]); + } + // The elementwise nodes will be created with sliced shape. The original loop + // output corresponds to the dynamic-update-slice's update slice. + auto dus = body_root->mutable_operand(2); + CHECK_EQ(dus->opcode(), HloOpcode::kDynamicUpdateSlice); + outside_to_inside[original_output] = dus->mutable_operand(1); + add_users_if_available(original_output); + std::vector slice_offsets(padded_shape.rank()); + for (int64 i = 0; i < slice_offsets.size(); ++i) { + slice_offsets[i] = dus->mutable_operand(i + 2); + } + auto get_slice = [&](HloInstruction* padded) { + return body->AddInstruction(HloInstruction::CreateDynamicSlice( + ShapeUtil::ChangeElementType(dus->operand(1)->shape(), + padded->shape().element_type()), + padded, slice_offsets, dus->operand(1)->shape().dimensions())); + }; + // Helper functions to create nodes with small operands. 
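+  // Broadcasts, iotas and constants are cheap to recompute, so rather than
+  // keeping their full-shape values alive outside the loop, they are
+  // re-created at the padded shape inside the body and then narrowed to the
+  // current iteration's window via get_slice above.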
+  auto add_broadcast = [&](const HloInstruction* broadcast) {
+    auto padded_operand_shape = broadcast->operand(0)->shape();
+    for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
+      padded_operand_shape.set_dimensions(
+          i, padded_shape.dimensions(broadcast->dimensions(i)));
+    }
+    auto padded_operand = PadToShape(outside_to_inside[broadcast->operand(0)],
+                                     padded_operand_shape, nullptr, body);
+    outside_to_inside[broadcast] =
+        get_slice(body->AddInstruction(broadcast->CloneWithNewOperands(
+            ShapeUtil::ChangeElementType(padded_shape,
+                                         padded_operand_shape.element_type()),
+            {padded_operand})));
+  };
+  auto add_iota = [&](const HloInstruction* iota) {
+    outside_to_inside[iota] =
+        get_slice(body->AddInstruction(iota->CloneWithNewOperands(
+            ShapeUtil::ChangeElementType(padded_shape,
+                                         iota->shape().element_type()),
+            {})));
+  };
+  auto add_constant = [&](const HloInstruction* constant) {
+    outside_to_inside[constant] = body->AddInstruction(constant->Clone());
+    outside_to_inside[constant] = get_slice(
+        PadToShape(outside_to_inside[constant],
+                   ShapeUtil::ChangeElementType(
+                       padded_shape, constant->shape().element_type()),
+                   nullptr, body));
+  };
+  while (!worklist.empty()) {
+    auto inst = worklist.back();
+    worklist.pop_back();
+    if (outside_to_inside.count(inst) > 0) {
+      continue;
+    }
+    if (inst->opcode() == HloOpcode::kBroadcast) {
+      add_broadcast(inst);
+    } else if (inst->opcode() == HloOpcode::kIota) {
+      add_iota(inst);
+    } else if (inst->opcode() == HloOpcode::kConstant) {
+      add_constant(inst);
+    } else if (inst->opcode() == HloOpcode::kReduce) {
+      // This is an output, for which we have special handling later.
+    } else {
+      std::vector<HloInstruction*> operands_inside(inst->operand_count());
+      for (int64 i = 0; i < operands_inside.size(); ++i) {
+        operands_inside[i] = outside_to_inside[inst->operand(i)];
+      }
+      outside_to_inside[inst] = body->AddInstruction(inst->CloneWithNewOperands(
+          ShapeUtil::ChangeElementType(dus->operand(1)->shape(),
+                                       inst->shape().element_type()),
+          operands_inside));
+    }
+    add_users_if_available(inst);
+  }
+  std::vector<HloInstruction*> new_outputs_inside(new_operands.size());
+  for (int64 i = 0; i < new_outputs_inside.size(); ++i) {
+    new_outputs_inside[i] = outside_to_inside[new_operands[i]];
+  }
+  // Now create the reduce outputs inside the loop.
+  for (int64 i = 0; i < reduce_outputs.size(); ++i) {
+    auto reduce_outside = reduce_outputs[i];
+    CHECK_EQ(reduce_outside->opcode(), HloOpcode::kReduce);
+    int64 index_in_operand = new_operands.size() - reduce_outputs.size() + i;
+    auto last_iter_result = outside_to_inside[new_operands[index_in_operand]];
+    auto operand0 = outside_to_inside[reduce_outside->operand(0)];
+    auto operand1 = outside_to_inside[reduce_outside->operand(1)];
+    TF_ASSIGN_OR_RETURN(auto reduce_shape,
+                        ShapeInference::InferReduceShape(
+                            {&operand0->shape(), &operand1->shape()},
+                            reduce_outside->dimensions(),
+                            reduce_outside->to_apply()->ComputeProgramShape()));
+    *reduce_shape.mutable_layout() = reduce_outside->shape().layout();
+    std::vector<HloInstruction*> reduce_dus_offsets;
+    // If any collapsed dimension is windowed, we need to accumulate with last
+    // iteration's result. If such a dimension has padding, we also need to
+    // mask off invalid data.
+ bool needs_accumulate = false; + std::vector dims_to_mask; + for (int64 i = 0; i < slice_offsets.size(); ++i) { + if (absl::c_linear_search(reduce_outside->dimensions(), i)) { + if (reduce_outside->operand(0)->shape().dimensions(i) != + operand0->shape().dimensions(i)) { + needs_accumulate = true; + if (unpadded_shape.dimensions(i) != padded_shape.dimensions(i)) { + dims_to_mask.push_back(i); + } + } + continue; + } + reduce_dus_offsets.push_back(slice_offsets[i]); + } + // Mask off invalid data in collapsed dimensions. + for (int64 dim : dims_to_mask) { + auto iota = body->AddInstruction(HloInstruction::CreateIota( + ShapeUtil::ChangeElementType(operand0->shape(), S32), dim)); + auto add = body->AddInstruction(HloInstruction::CreateBinary( + iota->shape(), HloOpcode::kAdd, iota, + body->AddInstruction(HloInstruction::CreateBroadcast( + iota->shape(), slice_offsets[dim], {})))); + auto limit = body->AddInstruction(HloInstruction::CreateBroadcast( + iota->shape(), + body->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0( + reduce_outside->operand(0)->shape().dimensions(dim)))), + {})); + auto compare = body->AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::ChangeElementType(iota->shape(), PRED), add, limit, + ComparisonDirection::kLt)); + operand0 = body->AddInstruction(HloInstruction::CreateTernary( + operand0->shape(), HloOpcode::kSelect, compare, operand0, + body->AddInstruction(HloInstruction::CreateBroadcast( + operand0->shape(), operand1, {})))); + } + auto output_inside = + body->AddInstruction(reduce_outside->CloneWithNewOperands( + reduce_shape, {operand0, operand1})); + // Accumulate with previous results if needed. + if (needs_accumulate) { + auto input_slice = + body->AddInstruction(HloInstruction::CreateDynamicSlice( + output_inside->shape(), last_iter_result, reduce_dus_offsets, + output_inside->shape().dimensions())); + output_inside = body->AddInstruction(HloInstruction::CreateBinary( + output_inside->shape(), + reduce_outside->to_apply()->root_instruction()->opcode(), + output_inside, input_slice)); + } + // Dynamic-update-slice if needed. + if (!ShapeUtil::Compatible(output_inside->shape(), + last_iter_result->shape())) { + output_inside = + body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + last_iter_result->shape(), last_iter_result, output_inside, + reduce_dus_offsets)); + } + new_outputs_inside[index_in_operand] = output_inside; + } + // Body output. + auto new_output_inside = + body->AddInstruction(HloInstruction::CreateTuple(new_outputs_inside)); + TF_RETURN_IF_ERROR( + body_root->ReplaceOperandWithDifferentShape(2, new_output_inside)); + TF_RETURN_IF_ERROR(body->RemoveInstructionAndUnusedOperands(dus)); + // Replace uses of the reduces outside the loop. 
+ auto new_output_gte = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + new_output_inside->shape(), loop, 2)); + for (int64 i = 0; i < reduce_outputs.size(); ++i) { + int64 index_in_operand = new_operands.size() - reduce_outputs.size() + i; + auto new_output = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + new_outputs_inside[index_in_operand]->shape(), new_output_gte, + index_in_operand)); + if (!ShapeUtil::Compatible(new_output->shape(), + reduce_outputs[i]->shape())) { + new_output = computation->AddInstruction(HloInstruction::CreateSlice( + reduce_outputs[i]->shape(), new_output, + std::vector(new_output->shape().rank(), 0), + reduce_outputs[i]->shape().dimensions(), + std::vector(new_output->shape().rank(), 1))); + } + TF_RETURN_IF_ERROR(reduce_outputs[i]->ReplaceAllUsesWith(new_output)); + TF_RETURN_IF_ERROR( + computation->RemoveInstructionAndUnusedOperands(reduce_outputs[i])); + } + return Status::OK(); +} + +} // namespace + +Status SpmdPartitioningVisitor::DoCodeMotionForWindowedDotGeneralLoops( + HloComputation* computation) { + for (auto& loop : windowed_dot_general_loops_) { + if (loop.windowed_in_contracting_dims || loop.windowed_in_batch_dims) { + // We have a dynamic-slice for the non-windowed operand in + // batch/contracting-dim windowed dot-general. So moving the + // broadcast/iota/elementwise ops into the loop could help reduce memory + // via fusion. + TF_RETURN_IF_ERROR( + SinkInputNodesIntoWindowedDotGeneralLoopOnContractingDimensions( + loop.while_loop, 1 - loop.windowed_operand)); + } + if (!loop.windowed_in_contracting_dims) { + // We have a dynamic-update-slice for the output in + // batch/non-contracting-dim windowed dot-general. So moving reduce ops + // into the loop could help reduce memory. + TF_RETURN_IF_ERROR( + MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions( + loop.while_loop)); + } + } + return Status::OK(); +} + +StatusOr SpmdPartitioningVisitor::DoPartition( + HloComputation* computation, const HloSharding& root_sharding) { + VLOG(2) << "Partitioning computation " << computation->name() << " for " + << num_replicas_ << " replicas and " << num_partitions_ + << " partitions"; + TF_RETURN_IF_ERROR(computation->Accept(this)); + + HloModule* module = computation->parent(); + auto new_root = + GetPartitionedHlo(computation->root_instruction()).Reshard(root_sharding); + auto new_computation = + module->AddEmbeddedComputation(b_.Build(new_root.hlo())); + TF_RETURN_IF_ERROR(DoCodeMotionForWindowedDotGeneralLoops(new_computation)); + + // Replace the original computation with the new SPMD computation. 
+ std::unordered_map replacement; + replacement[computation] = new_computation; + module->ReplaceComputations(replacement); + return changed_; +} + +Status SpmdPartitioningVisitor::HandlePartitionId(HloInstruction* hlo) { + return Unimplemented( + "PartitionId instruction is not supported for SPMD partitioning since " + "the meaning is ambiguous -- whether the instruction is replicated or " + "the data is replicated, and if the latter which data is replicated."); +} + +SpmdPartitioner::SpmdPartitioner(int64 num_partitions, int64 num_replicas, + SpmdPartitionerOptions options) + : SpmdPartitioner( + num_partitions, num_replicas, std::move(options), + SPMDCollectiveOpsCreator{ + [](SpmdBuilder* b) { + return b->AddInstruction(HloInstruction::CreatePartitionId()); + }, + [num_replicas](SpmdBuilder* b, HloInstruction* operand, + HloComputation* reduction, int64 channel_id) { + return b->AddInstruction(HloInstruction::CreateAllReduce( + operand->shape(), {operand}, reduction, + CreateReplicaGroups(num_replicas), + /*constrain_layout=*/false, channel_id, + /*use_global_device_ids=*/false)); + }, + [](SpmdBuilder* b, HloInstruction* operand, + std::vector>& src_dst_pairs, + int64 channel_id) { + return b->AddInstruction( + HloInstruction::CreateCollectivePermute( + operand->shape(), operand, src_dst_pairs, channel_id)); + }, + [](SpmdBuilder* b, absl::Span operands, + const std::vector& replica_groups, + int64 channel_id, absl::optional split_dimension) { + std::vector shapes(operands.size(), + operands[0]->shape()); + const Shape output_shape = + (shapes.size() == 1) ? shapes[0] + : ShapeUtil::MakeTupleShape(shapes); + return b->AddInstruction(HloInstruction::CreateAllToAll( + output_shape, operands, replica_groups, + /*constrain_layout=*/false, channel_id, split_dimension)); + }, + }) {} + +StatusOr SpmdPartitioner::PartitionComputation( + HloComputation* computation, const HloSharding& root_sharding, + int64* next_channel_id, SpmdLogger* logger) { + auto visitor = + CreateVisitor(computation, num_partitions_, num_replicas_, + collective_ops_creator_, next_channel_id, logger, options_); + return visitor->DoPartition(computation, root_sharding); +} + +std::unique_ptr SpmdPartitioner::CreateVisitor( + HloComputation* computation, int64 num_partitions, int64 num_replicas, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdLogger* logger, + SpmdPartitionerOptions options) { + return absl::make_unique( + computation, num_partitions, num_replicas, collective_ops_creator, + next_channel_id, logger, std::move(options), this); +} + +StatusOr SpmdPartitioner::Run(HloModule* module) { + TF_RETURN_IF_ERROR(PreprocessSharding(module)); + + XLA_VLOG_LINES(1, SpmdLogger::ReportBeforePartition( + *module, options_.report_instruction_count)); + + // Add the parameters' and output's shardings to the module. 
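+  // Every entry parameter and the entry root must already carry a sharding
+  // annotation (checked below); they are recorded on the module via
+  // set_spmd_parameters_shardings and set_spmd_output_sharding before the
+  // computation is rewritten.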
+ std::vector entry_params_shardings; + for (int64 i = 0; i < module->entry_computation()->num_parameters(); ++i) { + auto param = module->entry_computation()->parameter_instruction(i); + CHECK(param->has_sharding()) << "Missing sharding in entry parameter " << i; + entry_params_shardings.push_back(param->sharding()); + } + module->set_spmd_parameters_shardings(entry_params_shardings); + auto entry_root = module->entry_computation()->root_instruction(); + CHECK(entry_root->has_sharding()) << "Missing sharding in entry root."; + module->set_spmd_output_sharding(entry_root->sharding()); + + FlattenCallGraph flatten; + TF_ASSIGN_OR_RETURN(auto changed, flatten.Run(module)); + + SpmdLogger logger(options_.report_instruction_count); + auto program_shape = module->entry_computation()->ComputeProgramShape(); + int64 next_channel_id = hlo_query::NextChannelId(*module); + TF_ASSIGN_OR_RETURN( + bool partition_changed, + PartitionComputation( + module->entry_computation(), + module->entry_computation()->root_instruction()->sharding(), + &next_channel_id, &logger)); + changed |= partition_changed; + + // For the entry computation, make sure that the root instruction and the + // parameters preserve their signatures. + auto new_program_shape = module->entry_computation()->ComputeProgramShape(); + if (!options_.allow_module_signature_change) { + TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()( + program_shape.result(), new_program_shape.result())) + << "Result shape changed for the entry computation"; + TF_RET_CHECK(program_shape.parameters_size() == + new_program_shape.parameters_size()) + << "Parameter count changed for the entry computation"; + for (int64 i = 0; i < program_shape.parameters_size(); ++i) { + TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()( + program_shape.parameters(i), new_program_shape.parameters(i))) + << "Parameter shape changed for the entry computation"; + } + } else { + const auto& old_entry_layout = module->entry_computation_layout(); + // Shapes can change but the layout should still remain the same. 
+ for (int64 i = 0; i < new_program_shape.parameters_size(); ++i) { + TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + old_entry_layout.parameter_shape(i), + new_program_shape.mutable_parameters(i))); + } + TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + old_entry_layout.result_shape(), new_program_shape.mutable_result())); + + HloModuleConfig config = module->config(); + *config.mutable_entry_computation_layout() = + ComputationLayout(new_program_shape, /*ignore_layouts=*/false); + module->set_config(config); + } + + XLA_VLOG_LINES(1, SpmdLogger::ReportAfterPartition( + *module, options_.report_instruction_count)); + XLA_VLOG_LINES(1, logger.MakeReport()); + + if (changed) { + HloPassPipeline pass("spmd-cleanup"); + pass.AddPass(); + pass.AddPass(); + pass.AddPass(/*is_layout_sensitive=*/true); + pass.AddPass(); + TF_RETURN_IF_ERROR(pass.Run(module).status()); + } + + TF_RETURN_IF_ERROR(ClearShardingAttributes(module)); + return changed; +} + +Status SpmdPartitioner::PreprocessSharding(HloModule* module) { + for (HloComputation* computation : module->computations()) { + for (HloInstruction* hlo : computation->instructions()) { + if (hlo->HasSideEffectNoRecurse() && hlo->opcode() != HloOpcode::kRng) { + TF_RET_CHECK(hlo->has_sharding()) + << "Side-effect HLO must have sharding: " << hlo->ToString(); + TF_RET_CHECK(!HasReplicatedSharding(hlo->sharding()) || + hlo->opcode() == HloOpcode::kInfeed) + << "Non-infeed side-effect HLO cannot have a replicated sharding:" + << hlo->ToString(); + } + + // For unassigned HLOs, annotate with replicated sharding. + // + // Among side-effecting ops, only Rng is allowed to omit the annotation. + // In that case, we currently force it to run on core 0, since we don't + // support partitioning or replicating the Rng op (the values depend on + // the seed provided to each device). + // + // TODO(hyouklee): Should we also convert single-device shardings (without + // side-effects) into replicated? + if (!hlo->has_sharding()) { + if (hlo->opcode() == HloOpcode::kRng) { + hlo->set_sharding(HloSharding::AssignDevice(0)); + } else { + hlo->set_sharding( + HloSharding::Single(hlo->shape(), HloSharding::Replicate())); + } + } else if (!hlo->sharding().IsTileMaximal()) { + std::vector available(num_partitions_); + std::iota(available.begin(), available.end(), 0); + TF_RET_CHECK(num_partitions_ == hlo_sharding_util::DevicesForSharding( + hlo->sharding(), available) + .size()) + << "num_partitions:" << num_partitions_ << "\n" + << "SPMD partitioner only supports tile sharding that includes all " + "partitions. If you didn't add this sharding annotation in the " + "model, please file a bug to XLA team.\n" + << hlo->ToString(); + } + } + } + + // Entry computation's parameter and root sharding must be either all + // replicated or all on a single device. 
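+  // e.g. sharding={replicated} or sharding={maximal device=0}. This
+  // constraint is only enforced when allow_module_signature_change is false,
+  // since a tiled entry sharding would change the parameter or result shapes.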
+ if (!options_.allow_module_signature_change) { + const HloComputation* entry = module->entry_computation(); + TF_RET_CHECK(entry->root_instruction()->has_sharding()); + const HloSharding& root_sharding = entry->root_instruction()->sharding(); + TF_RET_CHECK(root_sharding.IsReplicated() || + root_sharding.UniqueDevice().has_value()) + << "Unsupported entry root sharding: " << root_sharding.ToString(); + + for (const HloInstruction* param : entry->parameter_instructions()) { + TF_RET_CHECK(param->has_sharding()); + TF_RET_CHECK(param->sharding().IsReplicated() || + param->sharding().UniqueDevice().has_value()) + << "Unsupported entry parameter sharding:" + << param->sharding().ToString(); + } + } + + return Status::OK(); +} + +} // namespace spmd +} // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h new file mode 100644 index 00000000000..f22f564be73 --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h @@ -0,0 +1,436 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_ + +#include +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" + +namespace xla { +namespace spmd { + +struct SpmdPartitionerOptions { + // Always exchange halo on LHS for all convolutions. If false, backprop filter + // convolution exchanges halo on RHS. + bool conv_halo_exchange_always_on_lhs = true; + + // The number of instructions to be reported for the highest memory profile + // instructions. + int64 report_instruction_count = 5; + + // The minimum size in MiB of an einsum operand to be considered using + // windowed implementation in an HLO loop. + int64 threshold_for_windowed_einsum_mib = 256; + + // Whether the entry computations' signature could change after partitioning. + bool allow_module_signature_change = false; +}; + +// Class to wrap the computation builder to capture information during SPMD +// transformation. 
+class SpmdBuilder : public HloComputation::Builder { + public: + SpmdBuilder(const std::string& name, HloInstruction* hlo) + : HloComputation::Builder(name) { + visiting_hlo_ = hlo; + } + HloInstruction* AddInstruction(std::unique_ptr instruction); + + const std::vector& derived_instructions( + HloInstruction* hlo) { + return instructions_.at(hlo); + } + + void set_visiting_hlo(HloInstruction* hlo) { visiting_hlo_ = hlo; } + + HloInstruction* visiting_hlo() const { return visiting_hlo_; } + + private: + // Currently visiting instruction. + HloInstruction* visiting_hlo_; + + // Map from the currently visiting (old) instruction to new instructions + // created during SPMD partitioning. + HloInstructionMap> instructions_; +}; + +// A set of functions that create the cross-partition collective ops. +struct SPMDCollectiveOpsCreator { + // Function used to create a partition ID HLO. + std::function create_partition_id; + + // Function used to create a cross-partition all-reduce HLO. + std::function + create_cross_partition_all_reduce; + + // Function used to create a cross-partition collective-permute HLO. + std::function>& src_dst_pairs, + int64 next_channel_id)> + create_cross_partition_collective_permute; + + // Function used to create a cross-partition all-to-all HLO. + std::function operands, + const std::vector& replica_groups, int64 channel_id, + absl::optional split_dimension)> + create_cross_partition_all_to_all; +}; + +// Logger to report memory usage during SPMD partitioning. +class SpmdLogger { + public: + explicit SpmdLogger(int64 report_instruction_count) + : report_instruction_count_(report_instruction_count) {} + static std::string ReportBeforePartition(const HloModule& module, + int64 report_instruction_count); + static std::string ReportAfterPartition(const HloModule& module, + int64 report_instruction_count); + + // Registers the logging for the groups of instructions created to transform + // the given hlo. + void RegisterLogEntry(HloInstruction* hlo, + const std::vector& group); + + std::string MakeReport(); + + private: + template + static std::string ReportMemoryUsage(const HloModule& module, const F& filter, + int64 report_instruction_count); + + // A vector of logging messages (one for each original HLO instruction), where + // the first integer of the pair represents the size of the HBM used. + std::vector> entries_; + + int64 report_instruction_count_; +}; + +class SpmdPartitioningVisitor; + +class SpmdPartitioner : public HloModulePass { + public: + SpmdPartitioner(int64 num_partitions, int64 num_replicas, + SpmdPartitionerOptions options); + SpmdPartitioner(int64 num_partitions, int64 num_replicas, + SpmdPartitionerOptions options, + SPMDCollectiveOpsCreator collective_ops_creator) + : num_partitions_(num_partitions), + num_replicas_(num_replicas), + options_(std::move(options)), + collective_ops_creator_(std::move(collective_ops_creator)) {} + absl::string_view name() const override { return "spmd-partitioning"; } + StatusOr Run(HloModule* module) override; + + // Transforms the given computation with SPMD instructions, replacing it with + // a new computation. 
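+  // The pass is typically driven through HloModulePass::Run; a minimal sketch
+  // (the pipeline name and partition count below are illustrative):
+  //
+  //   SpmdPartitionerOptions options;
+  //   options.allow_module_signature_change = true;
+  //   HloPassPipeline pipeline("spmd-partitioning");
+  //   pipeline.AddPass<SpmdPartitioner>(/*num_partitions=*/8,
+  //                                     /*num_replicas=*/1, options);
+  //   TF_RETURN_IF_ERROR(pipeline.Run(module).status());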
+ StatusOr PartitionComputation(HloComputation* computation, + const HloSharding& root_sharding, + int64* next_channel_id, + SpmdLogger* logger); + + protected: + virtual std::unique_ptr CreateVisitor( + HloComputation* computation, int64 num_partitions, int64 num_replicas, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdLogger* logger, + SpmdPartitionerOptions options); + + private: + // Verify that the sharding of instructions in the module are valid, and also + // fill in missing sharding information. + Status PreprocessSharding(HloModule* module); + + const int64 num_partitions_; + const int64 num_replicas_; + + SpmdPartitionerOptions options_; + SPMDCollectiveOpsCreator collective_ops_creator_; +}; + +// Class describes partition state of the data represented by an HLO created +// during SPMD partitioning pass. +// +// Data on some devices may include padding region, if the base (full) shape +// could not be evenly partitioned. +class PartitionedHlo { + public: + // Return value for ReshardAsWindowedInput which describes the resharded HLO, + // the window for the user on the shard, and if necessary, the dynamic slice + // offsets to be applied to the output of the op being sharded. + struct WindowedInputShardReturnValue { + HloInstruction* sharded_input; + Window shard_window; + absl::optional> dynamic_slice_index_on_output; + }; + // A cache for resharding each partitioned HLO. + struct ReshardCache { + struct PerHloCache { + std::vector> reshard_cache; + std::vector< + std::tuple> + window_reshard_cache; + }; + std::unordered_map per_hlo_cache; + }; + struct PartitioningState { + SpmdBuilder* b; + HloModule* module; + int64 num_replicas; + HloInstruction* partition_id; + SPMDCollectiveOpsCreator collective_ops_creator; + int64* next_channel_id; + ReshardCache* reshard_cache; + }; + PartitionedHlo(HloInstruction* hlo, Shape base_shape, PartitioningState state) + : hlo_(hlo), base_shape_(base_shape), state_(std::move(state)) { + CHECK(hlo->has_sharding()) + << "PartitionedHlo is missing sharding:" << hlo->ToString(); + // If the tuple shape instruction does not have a tuple sharding, reassign + // to use the tuple sharding. Reshard() implementation assumes this. + if (hlo_->shape().IsTuple() && !hlo_->sharding().IsTuple()) { + hlo_->set_sharding( + hlo_->sharding().GetTupleSharding(hlo_->shape()).ValueOrDie()); + } + } + + // Reshards the current SPMD instruction to a new sharding. Could only modify + // the reshard cache. + PartitionedHlo Reshard(const HloSharding& target); + + // Pads the garbage area of the output with the provided value. + PartitionedHlo PadWithValue(HloInstruction* pad_value) const; + + // Returns the SPMD instruction. + HloInstruction* hlo() const { return hlo_; } + + // Returns the sharding of the SPMD instruction. + const HloSharding& sharding() const { return hlo_->sharding(); } + + // Original full shape of the data. + const Shape& base_shape() const { return base_shape_; } + + int64 NewChannel() const { return (*state_.next_channel_id)++; } + + // Reshards the HLO to a usable partitioned input for a windowed user. Could + // only modify the reshard cache. + absl::optional ReshardAsWindowedInput( + const Window& window, const HloSharding& target, + HloInstruction* pad_value, bool mask_invalid_region = true); + + private: + // Same as Reshard except that it does not explicitly modify the reshard + // cache, although it would indirectly modify by calling Replicate(). 
+ PartitionedHlo ReshardNoCache(const HloSharding& target); + + // Helper function to replicate the data on all devices. Could only modify + // the reshard cache. + PartitionedHlo Replicate(); + + // Helper function to broadcast data from a single device to all devices. + PartitionedHlo Broadcast() const; + + // Helper function to reshard the tensor using AllToAll (instead of the + // default of Replicate followed by Slice). + PartitionedHlo ReshardWithAllToAll(const HloSharding& target) const; + + // Helper function to reshard the tensor using CollectivePermute. + PartitionedHlo ReshardWithCollectivePermute(const HloSharding& target) const; + + // SPMD instruction. + HloInstruction* hlo_; + + // The original shape of the data before SPMD transformation is applied. + Shape base_shape_; + + PartitioningState state_; +}; + +struct DotGeneralDimsMapping { + // The dimension numbers for the operands and output corresponding to a + // logical dimension (e.g., batch, contracting, non-contracting). If an + // operand or the output doesn't have the logical dimension, it is set to + // -1. + struct DimsMapping { + int64 lhs; + int64 rhs; + int64 output; + }; + std::vector batch_dims; + std::vector contracting_dims; + std::vector lhs_non_contracting_dims; + std::vector rhs_non_contracting_dims; +}; + +class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault { + public: + SpmdPartitioningVisitor( + HloComputation* computation, int64 num_partitions, int64 num_replicas, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdLogger* logger, + SpmdPartitionerOptions options, SpmdPartitioner* partitioner); + + Status DefaultAction(HloInstruction* hlo) override; + Status HandleAllReduce(HloInstruction* hlo) override; + Status HandleBroadcast(HloInstruction* hlo) override; + Status HandleConstant(HloInstruction* hlo) override; + Status HandleCustomCall(HloInstruction* hlo) override; + Status HandleDot(HloInstruction* hlo) override; + Status HandleDynamicSlice(HloInstruction* hlo) override; + Status HandleDynamicUpdateSlice(HloInstruction* hlo) override; + Status HandleGather(HloInstruction* hlo) override; + Status HandleGetTupleElement(HloInstruction* hlo) override; + Status HandleInfeed(HloInstruction* hlo) override; + Status HandleOutfeed(HloInstruction* hlo) override; + Status HandlePad(HloInstruction* hlo) override; + Status HandleParameter(HloInstruction* hlo) override; + Status HandleReduce(HloInstruction* hlo) override; + Status HandleReverse(HloInstruction* hlo) override; + Status HandleWhile(HloInstruction* hlo) override; + Status HandleConditional(HloInstruction* hlo) override; + Status HandleReduceWindow(HloInstruction* hlo) override; + Status HandleSelectAndScatter(HloInstruction* hlo) override; + Status HandleTuple(HloInstruction* hlo) override; + Status HandleRng(HloInstruction* hlo) override; + Status HandleConvolution(HloInstruction* hlo) override; + Status HandleConcatenate(HloInstruction* hlo) override; + Status HandleScatter(HloInstruction* hlo) override; + Status HandleSlice(HloInstruction* hlo) override; + Status HandleSort(HloInstruction* hlo) override; + Status HandleTranspose(HloInstruction* hlo) override; + Status HandleReshape(HloInstruction* hlo) override; + Status HandleIota(HloInstruction* hlo) override; + Status HandlePartitionId(HloInstruction* hlo) override; + + // Handles convolution where both LHS and RHS operands are tiled. 
+ Status HandleConvolutionTiledLhsAndRhs(HloInstruction* hlo); + + // Implementation of dot partitioning given DotGeneralDimsMapping. + Status HandleDotHelper( + HloInstruction* hlo, const DotGeneralDimsMapping& dims_mapping, + const std::function( + HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot); + + // Common handle for elementwise HLOs. + Status HandleElementwise(HloInstruction* hlo); + + // Common handle for HLOs that runs on a single device. + Status HandleSingleDevice(const HloInstruction* hlo); + + // Returns the PartitionedHlo that corresponds to the original hlo. + PartitionedHlo& GetPartitionedHlo(const HloInstruction* hlo) { + CHECK_EQ(partitioned_instructions_.count(hlo), 1); + return partitioned_instructions_.find(hlo)->second; + } + + // Sets the PartitionedHlo for the original hlo. + void SetPartitionedHlo(const HloInstruction* hlo, + const PartitionedHlo& partitioned_hlo) { + CHECK_EQ(partitioned_instructions_.count(hlo), 0); + partitioned_instructions_.emplace(hlo, partitioned_hlo); + changed_ = true; + } + + // Convenient wrapper that creates PartitionedHlo from the result of the func + // and maps it to the given original hlo. + void SetPartitionedHlo(const HloInstruction* hlo, + const std::function& func) { + HloInstruction* new_hlo = func(); + new_hlo->set_sharding(hlo->sharding()); + new_hlo->set_metadata(hlo->metadata()); + SetPartitionedHlo( + hlo, PartitionedHlo(new_hlo, hlo->shape(), MakePartitioningState())); + changed_ = true; + } + + int64 NewChannel() { return (*next_channel_id_)++; } + + PartitionedHlo::PartitioningState MakePartitioningState() { + PartitionedHlo::PartitioningState state; + state.b = &b_; + state.module = module_; + state.num_replicas = num_replicas_; + state.partition_id = partition_id_; + state.collective_ops_creator = collective_ops_creator_; + state.next_channel_id = next_channel_id_; + state.reshard_cache = &reshard_cache_; + return state; + } + + SpmdBuilder* builder() { return &b_; } + + StatusOr DoPartition(HloComputation* computation, + const HloSharding& root_sharding); + + private: + Status Preprocess(HloInstruction* hlo) override; + Status Postprocess(HloInstruction* hlo) override; + + // Performs code motion for windowed dot-general loops in + // windowed_dot_general_loops_. Invoked after the visitor finishes traversing + // the graph. + Status DoCodeMotionForWindowedDotGeneralLoops(HloComputation* computation); + + bool changed_; + HloModule* module_; + int64 num_partitions_; + int64 num_replicas_; + + SPMDCollectiveOpsCreator collective_ops_creator_; + + // Tracks the next channel id to use for cross-partition all-reduce. + int64* next_channel_id_; + SpmdBuilder b_; + + HloInstruction* partition_id_; + + PartitionedHlo::ReshardCache reshard_cache_; + + // Mapping from the instruction in the original computation to the new SPMD + // partitioned instruction. + ConstHloInstructionMap partitioned_instructions_; + + // Information about a loop created for windowed dot-general. Used when + // DoCodeMotionForWindowedDotGeneralLoops() executes after the visitor + // finishes traversing the graph. 
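+  // windowed_operand records which dot operand (0 = LHS, 1 = RHS) is windowed
+  // through the loop; the two booleans indicate whether the windowing is over
+  // contracting or batch dimensions, which determines whether input sinking
+  // or user code motion is applied by
+  // DoCodeMotionForWindowedDotGeneralLoops().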
+ struct WindowedDotGeneralLoop { + HloInstruction* while_loop; + int64 windowed_operand; + bool windowed_in_contracting_dims; + bool windowed_in_batch_dims; + }; + std::vector windowed_dot_general_loops_; + + HloInstruction* visiting_hlo_; + SpmdLogger* logger_; + const SpmdPartitionerOptions options_; + SpmdPartitioner* partitioner_; +}; + +} // namespace spmd +} // namespace xla +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_ diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc new file mode 100644 index 00000000000..ca1afc816b0 --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc @@ -0,0 +1,3215 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace spmd { +namespace { + +using ::testing::_; +using ::testing::AllOf; +namespace op = xla::testing::opcode_matchers; + +class SpmdPartitioningTest : public HloTestBase { + public: + StatusOr> PartitionComputation( + const char* hlo_module, int64 num_devices, + bool conv_halo_exchange_always_on_lhs = true) { + // Some tests (BackpropFilter convs) set this flag false to test two + // different paths of the implementation. 
+ SpmdPartitionerOptions options; + options.conv_halo_exchange_always_on_lhs = conv_halo_exchange_always_on_lhs; + options.allow_module_signature_change = true; + + TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule( + hlo_module, GetModuleConfigForTest())); + HloPassPipeline pass("spmd-partitioning"); + pass.AddPass(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); + pass.AddPass(num_devices, /*num_replicas=*/1, options); + pass.AddPass(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); + TF_RETURN_IF_ERROR(pass.Run(module.get()).status()); + return StatusOr>(std::move(module)); + } +}; + +TEST_F(SpmdPartitioningTest, InvalidSharding) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + token0 = token[] after-all(), sharding={maximal device=0} + infeed = (f32[8,2]{1,0}, token[]) infeed(token0), + sharding={{devices=[2,1]0,1}, {maximal device=0}} + ROOT infeed.data = f32[8,2]{1,0} get-tuple-element(infeed), index=0, + sharding={maximal device=0} +})"; + auto module_status = PartitionComputation(hlo_string, /*num_devices=*/4); + EXPECT_FALSE(module_status.status().ok()); + EXPECT_THAT(module_status.status().ToString(), + ::testing::HasSubstr( + "only supports tile sharding that includes all partitions")); +} + +TEST_F(SpmdPartitioningTest, SingleDeviceToReplicated) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}), + sharding={maximal device=0} + ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Copy(op::AllReduce( + op::Select(op::Broadcast(op::Compare()), + op::Constant(), op::Broadcast()))), + op::Shape("s32[2,3]"))); +} + +TEST_F(SpmdPartitioningTest, SingleDeviceToSingleDevice) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}), + sharding={maximal device=0} + ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={maximal device=1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + HloInstruction* root = module->entry_computation()->root_instruction(); + VLOG(1) << module->ToString(); + EXPECT_THAT(root, op::Copy(AllOf(op::Copy(op::AllReduce(op::Select( + op::Broadcast(op::Compare()), + op::Constant(), op::Broadcast()))), + op::Shape("s32[2,3]")))); +} + +TEST_F(SpmdPartitioningTest, SingleDeviceToTiled) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}), + sharding={maximal device=0} + ROOT %copy = s32[2,3]{1,0} copy(%constant), + sharding={devices=[2,1]1,0} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf( + op::Copy(op::DynamicSlice( + op::AllReduce(op::Select( + op::Broadcast(op::Compare(op::PartitionId(), op::Constant())), + op::Constant(), op::Broadcast())), + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId(), + op::Constant())), + op::Constant())), + op::Shape("s32[1,3]"))); +} + +TEST_F(SpmdPartitioningTest, TiledToReplicated) { + const char* const hlo_string = R"( +HloModule module + 
+ENTRY entry { + %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}), + sharding={devices=[2,1]0,1} + ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + op::Copy(op::AllReduce(AllOf( + op::DynamicUpdateSlice( + op::Broadcast(), AllOf(op::Constant(), op::Shape("s32[1,3]")), + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId(), + op::Constant())), + op::Constant()), + op::Shape("s32[2,3]"))))); +} + +TEST_F(SpmdPartitioningTest, TiledToSingleDevice) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}), + sharding={devices=[2,1]0,1} + ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={maximal device=0} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + op::Copy(op::Copy(op::AllReduce(AllOf( + op::DynamicUpdateSlice( + op::Broadcast(), AllOf(op::Constant(), op::Shape("s32[1,3]")), + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId(), + op::Constant())), + op::Constant()), + op::Shape("s32[2,3]")))))); +} + +TEST_F(SpmdPartitioningTest, TiledToTiledEven) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param= s32[8,2]{1,0} parameter(0), sharding={devices=[2,1]0,1} + ROOT %copy = s32[8,2]{1,0} copy(%param), sharding={devices=[1,2]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::Copy(op::Reshape(op::Transpose(op::AllToAll(AllOf( + op::Reshape(op::Parameter()), op::Shape("s32[4,2,1]")))))), + op::Shape("s32[8,1]"))); +} + +TEST_F(SpmdPartitioningTest, TiledToTiledUneven) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param= f32[7,31,128]{2,1,0} parameter(0), sharding={devices=[1,2,1]0,1} + ROOT %copy = f32[7,31,128]{2,1,0} copy(%param), sharding={devices=[2,1,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::Copy(op::Slice(op::Reshape(AllOf(op::Transpose(op::AllToAll( + op::Reshape(AllOf(op::Pad(), op::Shape("f32[8,16,128]"))))))))))); +} + +TEST_F(SpmdPartitioningTest, GetTupleElementSwapDevice) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param.0 = (f32[2,3]{1,0}, u32[]) parameter(0), + sharding={{maximal device=1}, {maximal device=1}} + %gte.0 = f32[2,3]{1,0} get-tuple-element(%param.0), index=0, + sharding={maximal device=0} + %gte.1 = u32[] get-tuple-element(%param.0), index=1, + sharding={maximal device=0} + ROOT %tuple = (f32[2,3]{1,0}, u32[]) tuple(%gte.0, %gte.1), + sharding={{maximal device=0},{maximal device=0}} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + ASSERT_THAT(root, op::Tuple()); + + EXPECT_THAT(root->operand(0), + op::Copy(op::AllReduce(op::Select( + 
op::Broadcast(op::Compare(op::PartitionId(), op::Constant())), + op::GetTupleElement(op::Parameter()), op::Broadcast())))); + EXPECT_THAT(root->operand(1), + op::Copy(op::AllReduce(op::Select( + op::Broadcast(op::Compare(op::PartitionId(), op::Constant())), + op::GetTupleElement(op::Parameter()), op::Broadcast())))); +} + +TEST_F(SpmdPartitioningTest, GetTupleElementTiled) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + param.0 = (f32[2,3]{1,0}, u32[2,3]{1,0}) parameter(0), + sharding={{replicated}, {replicated}} + gte.0 = f32[2,3]{1,0} get-tuple-element(param.0), index=0, + sharding={devices=[2,1]0,1} + gte.1 = u32[2,3]{1,0} get-tuple-element(param.0), index=1, + sharding={devices=[2,1]0,1} + ROOT %tuple = (f32[2,3]{1,0}, u32[2,3]{1,0}) tuple(gte.0, gte.1), + sharding={{devices=[2,1]0,1},{devices=[2,1]0,1}} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + ASSERT_THAT(root, op::Tuple()); + + auto offset = op::Reshape( + op::DynamicSlice(op::Constant(), op::PartitionId(), op::Constant())); + + EXPECT_THAT(root->operand(0), + op::DynamicSlice(op::GetTupleElement(op::Parameter()), offset, + op::Constant())); + EXPECT_THAT(root->operand(1), + op::DynamicSlice(op::GetTupleElement(op::Parameter()), offset, + op::Constant())); +} + +TEST_F(SpmdPartitioningTest, TiledInfeed) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + token0 = token[] after-all(), sharding={maximal device=0} + infeed = (f32[8,2]{1,0}, token[]) infeed(token0), + sharding={{devices=[2,1]0,1}, {maximal device=0}} + ROOT infeed.data = f32[8,2]{1,0} get-tuple-element(infeed), index=0, + sharding={maximal device=0} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, op::Copy(op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(), + op::GetTupleElement( + AllOf(op::Infeed(), op::Shape("(f32[4,2]{1,0}, token[])"))), + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId(), + op::Constant())), + op::Constant())))); +} + +TEST_F(SpmdPartitioningTest, UnevenTiledInfeed) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + token0 = token[] after-all(), sharding={maximal device=0} + infeed = (f32[9,2]{1,0}, token[]) infeed(token0), + sharding={{devices=[2,1]0,1}, {maximal device=0}} + ROOT infeed.data = f32[9,2]{1,0} get-tuple-element(infeed), index=0, + sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, AllOf(op::Shape("f32[5,2]"), op::GetTupleElement(op::Conditional( + op::Convert(op::PartitionId()), + op::AfterAll(), op::AfterAll())))); + EXPECT_THAT( + root->operand(0)->called_computations()[0]->root_instruction(), + AllOf(op::Shape("(f32[5,2], token[])"), op::Infeed(op::Parameter()))); + auto second_infeed = + AllOf(op::Shape("(f32[4,2], token[])"), op::Infeed(op::Parameter())); + EXPECT_THAT(root->operand(0)->called_computations()[1]->root_instruction(), + AllOf(op::Shape("(f32[5,2], token[])"), + op::Tuple(op::Pad(op::GetTupleElement(second_infeed), + op::Constant()), + op::GetTupleElement(second_infeed)))); +} + 
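+// The uneven-tiling case above exercises the padding path: the partition
+// whose shard extends past the base shape receives a padded infeed selected
+// by a conditional on the partition id, which is why the matchers expect
+// op::Conditional(op::Convert(op::PartitionId()), ...).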
+TEST_F(SpmdPartitioningTest, UnevenTiledTupleInfeed) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + token0 = token[] after-all(), sharding={maximal device=0} + infeed = ((f32[9,2]{1,0}, f32[2]{0}), token[]) infeed(token0), + sharding={{devices=[2,1]0,1}, {replicated}, {maximal device=0}} + ROOT infeed.data = (f32[9,2]{1,0}, f32[2]{0}) get-tuple-element(infeed), + index=0, sharding={{devices=[2,1]0,1}, {replicated}} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("(f32[5,2], f32[2])"), + op::GetTupleElement(op::Conditional( + op::Convert(op::PartitionId()), op::AfterAll(), + op::AfterAll())))); + EXPECT_THAT(root->operand(0)->called_computations()[0]->root_instruction(), + AllOf(op::Shape("((f32[5,2], f32[2]), token[])"), + op::Infeed(op::Parameter()))); + auto second_infeed = AllOf(op::Shape("((f32[4,2], f32[2]), token[])"), + op::Infeed(op::Parameter())); + EXPECT_THAT( + root->operand(0)->called_computations()[1]->root_instruction(), + AllOf(op::Shape("((f32[5,2], f32[2]), token[])"), + op::Tuple(op::Tuple(op::Pad(op::GetTupleElement( + op::GetTupleElement(second_infeed)), + op::Constant()), + op::GetTupleElement( + op::GetTupleElement(second_infeed))), + op::GetTupleElement(second_infeed)))); +} + +TEST_F(SpmdPartitioningTest, TiledToReplicatedReduce) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + constant = f32[3,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1}}), + sharding={devices=[2,1]0,1} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT reduce = f32[] reduce(constant, constant.1), dimensions={0,1}, + to_apply=sum, sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + op::AllReduce(op::Reduce( + op::Select( + op::Compare(op::Add(op::Iota(), op::Broadcast(op::Reshape())), + op::Broadcast(op::Constant())), + AllOf(op::Shape("f32[2,3]{1,0}"), + op::DynamicSlice(op::Pad(op::Constant(), op::Constant()), + op::Reshape(), op::Constant())), + op::Broadcast(op::Constant())), + op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, TiledElementwise) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + constant = f32[3,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1}}), + sharding={devices=[2,1]0,1} + constant.1 = f32[3,3]{1,0} constant({{2,2,2},{2,2,2},{2,2,2}}), + sharding={replicated} + multiply = f32[3,3]{1,0} multiply(constant, constant.1), + sharding={devices=[2,1]0,1} + ROOT add = f32[3,3]{1,0} add(multiply, constant.1), + sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf( + op::Shape("f32[2,3]{1,0}"), + op::Add(op::Multiply( + op::DynamicSlice(op::Pad(op::Constant(), op::Constant()), + op::Reshape(), op::Constant()), + op::DynamicSlice(op::Pad(op::Constant(), op::Constant()), + op::Reshape(), op::Constant())), + op::DynamicSlice(op::Pad(op::Constant(), op::Constant()), + op::Reshape(), op::Constant())))); +} + 
+TEST_F(SpmdPartitioningTest, TiledAllReduce) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + parameter = f32[3,3]{1,0} parameter(0), sharding={devices=[2,1]0,1} + ROOT all-reduce = f32[3,3]{1,0} all-reduce(parameter), to_apply=sum, + replica_groups={}, sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, AllOf(op::Shape("f32[2,3]{1,0}"), op::AllReduce(op::Parameter(0)))); +} + +TEST_F(SpmdPartitioningTest, BroadcastOnlyNewDimsSharded) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + constant = f32[4,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1},{1,1,1}}), + sharding={replicated} + ROOT broadcast = f32[3,4,3]{2,1,0} broadcast(constant), dimensions={1,2}, + sharding={devices=[2,1,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[2,4,3]{2,1,0}"), + op::Broadcast(op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, BroadcastOnlyOldDimsSharded) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + constant = f32[4,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1},{1,1,1}}), + sharding={replicated} + ROOT broadcast = f32[4,4,3]{2,1,0} broadcast(constant), dimensions={1,2}, + sharding={devices=[1,2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[4,2,3]{2,1,0}"), + op::Broadcast(op::DynamicSlice( + op::Constant(), op::Reshape(), op::Constant())))); +} + +TEST_F(SpmdPartitioningTest, BroadcastBothOldAndNewDimsSharded) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + constant = f32[4,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1},{1,1,1}}), + sharding={replicated} + ROOT broadcast = f32[4,4,3]{2,1,0} broadcast(constant), dimensions={1,2}, + sharding={devices=[2,2,1]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::Shape("f32[2,2,3]{2,1,0}"), + op::Broadcast(AllOf(op::Shape("f32[2,3]{1,0}"), + op::DynamicSlice(op::Constant(), op::Reshape(), + op::Constant()))))); +} + +TEST_F(SpmdPartitioningTest, BroadcastPropagateTiledSharding) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + constant = f32[4,3]{1,0} constant({{1,1,1},{1,4,1},{1,3,1},{1,2,1}}), + sharding={devices=[2,1]0,1} + ROOT broadcast = f32[4,4,3]{2,1,0} broadcast(constant), dimensions={1,2}, + sharding={devices=[1,2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[4,2,3]{2,1,0}"), + op::Broadcast(op::DynamicSlice( + op::Constant(), op::Reshape(), op::Constant())))); +} + +TEST_F(SpmdPartitioningTest, OutfeedSingleDevice) { + 
const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + token.0 = token[] after-all() + data = f32[1024]{0} parameter(0), sharding={maximal device=0} + outfeed = token[] outfeed(data, token.0), sharding={maximal device=0} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("token[]"), + op::Conditional( + op::Compare(op::PartitionId(), op::Constant()), + op::Tuple(op::Parameter(0), op::AfterAll()), + op::Tuple(op::Parameter(0), op::AfterAll())))); + + HloInstruction* root_b0 = root->branch_computation(0)->root_instruction(); + EXPECT_THAT(root_b0, + AllOf(op::Shape("token[]"), + op::Outfeed(op::GetTupleElement(op::Parameter(), 0), + op::GetTupleElement(op::Parameter(), 1)))); + + HloInstruction* root_b1 = root->branch_computation(1)->root_instruction(); + EXPECT_THAT(root_b1, AllOf(op::Shape("token[]"), op::AfterAll())); +} + +TEST_F(SpmdPartitioningTest, ReduceWindowReplicatedInput) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + constant = f32[6,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1},{1,2},{2,2}}), + sharding={replicated} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT reduce-window = f32[3,2]{1,0} reduce-window(constant, constant.1), + window={size=3x1 stride=2x1 pad=1_0x0_0}, to_apply=sum, + sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::Shape("f32[2,2]{1,0}"), + op::ReduceWindow( + op::DynamicSlice(AllOf(op::Shape("f32[9,2]{1,0}"), + op::Pad(op::Constant(), op::Constant())), + op::Multiply(op::Reshape(), op::Constant()), + op::Constant()), + op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, ReduceWindowTiledNegativeLeftHalo) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + constant = f32[6,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1},{1,2},{2,2}}), + sharding={devices=[2,1]0,1} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT %reduce-window = f32[3,2]{1,0} reduce-window(%constant, %constant.1), + window={size=3x1 stride=2x1 pad=0_1x0_0}, to_apply=sum, + sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + + auto sharded_input = + op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant()); + auto right_halo = AllOf(op::Shape("f32[2,2]{1,0}"), + op::CollectivePermute(op::Slice(sharded_input))); + auto pre_masking = op::DynamicSlice( + AllOf( + op::Shape("f32[6,2]{1,0}"), + op::Pad(op::Concatenate(sharded_input, right_halo), op::Constant())), + op::Reshape(), op::Constant()); + auto index_in_padded = op::Add( + op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant()))); + auto masked = + op::Select(op::Compare(index_in_padded, op::Broadcast(op::Constant())), + pre_masking, op::Broadcast(op::Constant())); + EXPECT_THAT(root, AllOf(op::Shape("f32[2,2]{1,0}"), + op::ReduceWindow(masked, 
op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, ReduceWindowTiledOneSideUnequalHalo) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + constant = f32[9,2]{1,0} constant( + {{1,1},{1,4},{2,1},{3,1},{1,2},{2,2},{4,1},{1,2},{2,1}}), + sharding={devices=[3,1]0,1,2} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT reduce-window = f32[5,2]{1,0} reduce-window(constant, constant.1), + window={size=3x1 stride=2x1 pad=1_1x0_0}, to_apply=sum, + sharding={devices=[3,1]0,1,2} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/3)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + + auto sharded_input = + op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant()); + auto right_halo = AllOf(op::Shape("f32[2,2]{1,0}"), + op::CollectivePermute(op::Slice(sharded_input))); + auto pre_masking = op::DynamicSlice( + AllOf( + op::Shape("f32[7,2]{1,0}"), + op::Pad(op::Concatenate(sharded_input, right_halo), op::Constant())), + op::Reshape(), op::Constant()); + auto index_in_padded = op::Add( + op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant()))); + auto masked = op::Select( + op::And(op::Compare(index_in_padded, op::Broadcast(op::Constant())), + op::Compare(index_in_padded, op::Broadcast(op::Constant()))), + pre_masking, op::Broadcast(op::Constant())); + EXPECT_THAT(root, AllOf(op::Shape("f32[2,2]{1,0}"), + op::ReduceWindow(masked, op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, ReduceWindowTiledTwoSideHalo) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}}), + sharding={devices=[2,1]0,1} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT reduce-window = f32[2,2]{1,0} reduce-window(constant, constant.1), + window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum, + sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + + auto sharded_input = + op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant()); + auto left_halo = AllOf(op::Shape("f32[1,2]{1,0}"), + op::CollectivePermute(op::Slice(sharded_input))); + auto right_halo = AllOf(op::Shape("f32[1,2]{1,0}"), + op::CollectivePermute(op::Slice(sharded_input))); + auto pre_masking = AllOf( + op::Shape("f32[5,2]{1,0}"), + op::DynamicSlice( + AllOf(op::Shape("f32[6,2]{1,0}"), + op::Pad(op::Concatenate(left_halo, sharded_input, right_halo), + op::Constant())), + op::Reshape(), op::Constant())); + auto index_in_padded = op::Add( + op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant()))); + auto masked = op::Select( + op::And(op::Compare(index_in_padded, op::Broadcast(op::Constant())), + op::Compare(index_in_padded, op::Broadcast(op::Constant()))), + pre_masking, op::Broadcast(op::Constant())); + EXPECT_THAT(root, AllOf(op::Shape("f32[1,2]{1,0}"), + op::ReduceWindow(masked, op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, ReduceWindowTiled2D) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + token0 = 
token[] after-all(), sharding={maximal device=0} + infeed = (f32[4,4,2,2]{3,2,1,0}, token[]) infeed(token0), + sharding={{devices=[2,2,1,1]0,1,2,3}, {maximal device=0}} + infeed.data = f32[4,4,2,2]{3,2,1,0} get-tuple-element(infeed), index=0, + sharding={devices=[2,2,1,1]0,1,2,3} + constant = f32[] constant(0), sharding={replicated} + ROOT reduce-window = f32[2,2,2,2]{3,2,1,0} reduce-window(infeed.data, constant), + window={size=5x5x1x1 stride=3x3x1x1 pad=2_2x2_2x0_0x0_0}, to_apply=sum, + sharding={devices=[2,2,1,1]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + + auto sharded_input = AllOf(op::Shape("f32[2,2,2,2]{3,2,1,0}"), + op::GetTupleElement(op::Infeed())); + auto dim0_left_halo = AllOf(op::Shape("f32[1,2,2,2]{3,2,1,0}"), + op::CollectivePermute(op::Slice(sharded_input))); + auto dim0_right_halo = AllOf(op::Shape("f32[1,2,2,2]{3,2,1,0}"), + op::CollectivePermute(op::Slice(sharded_input))); + auto dim0_pre_masking = op::DynamicSlice( + AllOf(op::Shape("f32[6,2,2,2]{3,2,1,0}"), + op::Pad( + op::Concatenate(dim0_left_halo, sharded_input, dim0_right_halo), + op::Constant())), + op::Reshape(), op::Constant(), op::Constant(), op::Constant()); + auto dim0_index_in_padded = op::Add( + op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant()))); + auto dim0_masked = op::Select( + op::And(op::Compare(dim0_index_in_padded, op::Broadcast(op::Constant())), + op::Compare(dim0_index_in_padded, op::Broadcast(op::Constant()))), + dim0_pre_masking, op::Broadcast(op::Constant())); + auto dim0_resharded = AllOf(op::Shape("f32[5,2,2,2]{3,2,1,0}"), dim0_masked); + auto dim1_left_halo = AllOf(op::Shape("f32[5,1,2,2]{3,2,1,0}"), + op::CollectivePermute(op::Slice(dim0_resharded))); + auto dim1_right_halo = + AllOf(op::Shape("f32[5,1,2,2]{3,2,1,0}"), + op::CollectivePermute(op::Slice(dim0_resharded))); + auto dim1_pre_masking = op::DynamicSlice( + AllOf(op::Shape("f32[5,6,2,2]{3,2,1,0}"), + op::Pad(op::Concatenate(dim1_left_halo, dim0_resharded, + dim1_right_halo), + op::Constant())), + op::Constant(), op::Reshape(), op::Constant(), op::Constant()); + auto dim1_index_in_padded = op::Add( + op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant()))); + auto dim1_masked = op::Select( + op::And(op::Compare(dim1_index_in_padded, op::Broadcast(op::Constant())), + op::Compare(dim1_index_in_padded, op::Broadcast(op::Constant()))), + dim1_pre_masking, op::Broadcast(op::Constant())); + auto dim1_resharded = AllOf(op::Shape("f32[5,5,2,2]{3,2,1,0}"), dim1_masked); + EXPECT_THAT(root, AllOf(op::Shape("f32[1,1,2,2]{3,2,1,0}"), + op::ReduceWindow(dim1_resharded, op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsReplicated) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,224,224,3] parameter(0) + %lhs.copy = f32[128,224,224,3] copy(f32[128,224,224,3] %lhs), + sharding={devices=[1,2,1,1]0,1} + %rhs = f32[7,7,3,64] parameter(1) + %rhs.copy = f32[7,7,3,64] copy(f32[7,7,3,64] %rhs), + sharding={replicated} + ROOT %conv = f32[128,112,112,64] convolution( + f32[128,224,224,3] %lhs.copy, + f32[7,7,3,64] %rhs.copy), + window={size=7x7 stride=2x2 pad=3_3x3_3}, + dim_labels=b01f_01io->b01f, + sharding={devices=[1,2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = 
module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,112,224,3]")); + auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[7,7,3,64]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[128,3,224,3]")); + auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[128,2,224,3]")); + EXPECT_THAT(root, + AllOf(op::Convolution( + op::Select(op::And(), + op::Concatenate(left_halo, lhs, right_halo), + op::Broadcast()), + rhs), + op::Shape("f32[128,56,112,64]"))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsReplicatedNeedReshard) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,224,224,3] parameter(0) + %lhs.copy = f32[128,224,224,3] copy(f32[128,224,224,3] %lhs), + sharding={devices=[2,1,1,1]0,1} + %rhs = f32[7,7,3,64] parameter(1) + %rhs.copy = f32[7,7,3,64] copy(f32[7,7,3,64] %rhs), + sharding={replicated} + ROOT %conv = f32[128,112,112,64] convolution( + f32[128,224,224,3] %lhs.copy, + f32[7,7,3,64] %rhs.copy), + window={size=7x7 stride=2x2 pad=3_3x3_3}, + dim_labels=b01f_01io->b01f, + sharding={devices=[1,2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(), + op::Constant(), op::Constant())), + op::Shape("f32[64,224,224,3]")); + auto all_to_all = + AllOf(op::AllToAll(op::Reshape(lhs)), op::Shape("f32[64,2,112,224,3]")); + auto reshard_lhs = AllOf(op::Reshape(op::Transpose(all_to_all)), + op::Shape("f32[128,112,224,3]")); + + auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[7,7,3,64]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(reshard_lhs)), + op::Shape("f32[128,3,224,3]")); + auto right_halo = AllOf(op::CollectivePermute(op::Slice(reshard_lhs)), + op::Shape("f32[128,2,224,3]")); + EXPECT_THAT( + root, + AllOf(op::Convolution( + op::Select(op::And(), + op::Concatenate(left_halo, reshard_lhs, right_halo), + op::Broadcast()), + rhs), + op::Shape("f32[128,56,112,64]"))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsReplicatedReordered) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[224,224,3,128] parameter(0) + %lhs.copy = f32[224,224,3,128] copy(%lhs), sharding={devices=[2,1,1,1]0,1} + %rhs = f32[7,7,3,64] parameter(1) + %rhs.copy = f32[7,7,3,64] copy(%rhs), sharding={replicated} + ROOT %conv = f32[128,112,112,64] convolution(%lhs.copy, %rhs.copy), + window={size=7x7 stride=2x2 pad=3_3x3_3}, + dim_labels=01fb_01io->b01f, + sharding={devices=[1,2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(), + op::Constant(), op::Constant())), + op::Shape("f32[112,224,3,128]")); + auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[7,7,3,64]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[3,224,3,128]")); + auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[2,224,3,128]")); + EXPECT_THAT(root, 
+ AllOf(op::Convolution( + op::Select(op::And(), + op::Concatenate(left_halo, lhs, right_halo), + op::Broadcast()), + rhs), + op::Shape("f32[128,56,112,64]"))); +} + +// (stride * per_shard_window_count) % dilation == 0 +TEST_F(SpmdPartitioningTest, + ConvolutionBaseDilationSameStartPatternLhsTiledRhsReplicated) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,7,7,512] parameter(0) + %lhs.copy = f32[128,7,7,512] copy(%lhs), + sharding={devices=[1,2,1,1]0,1} + %rhs = f32[3,3,512,512] parameter(1) + %rhs.copy = f32[3,3,512,512] copy(%rhs), + sharding={replicated} + ROOT %conv = f32[128,4,4,512] convolution(%lhs.copy, %rhs.copy), + window={size=3x3 stride=4x4 pad=1_1x1_1 lhs_dilate=2x2 rhs_reversal=1x1}, + dim_labels=b01f_01io->b01f, + sharding={devices=[1,2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + // There is no halo exchange, and because the last element in the shard is not + // needed (stride == 4), the LHS will be just a slice. + auto sliced_lhs = + AllOf(op::Slice(op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Constant()))), + op::Shape("f32[128,3,7,512]")); + auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[3,3,512,512]")); + EXPECT_THAT(root, AllOf(op::Convolution(sliced_lhs, rhs), + op::Shape("f32[128,2,4,512]"))); + EXPECT_EQ(root->window().dimensions(0).padding_low(), 1); + EXPECT_EQ(root->window().dimensions(0).padding_high(), 1); +} + +// (stride * per_shard_window_count) % dilation != 0 but stride == 1 +TEST_F(SpmdPartitioningTest, + ConvolutionBaseDilationStride1LhsTiledRhsReplicated) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,7,7,512] parameter(0) + %lhs.copy = f32[128,7,7,512] copy(%lhs), + sharding={devices=[1,2,1,1]0,1} + %rhs = f32[3,3,512,512] parameter(1) + %rhs.copy = f32[3,3,512,512] copy(%rhs), + sharding={replicated} + ROOT %conv = f32[128,14,14,512] convolution(%lhs.copy, %rhs.copy), + window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, + dim_labels=b01f_01io->b01f, + sharding={devices=[1,2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Constant())), + op::Shape("f32[128,4,7,512]")); + auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[3,3,512,512]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[128,1,7,512]")); + auto start_window = op::Multiply(op::Reshape(), op::Constant()); + auto start_input_element = op::Divide(start_window, op::Constant()); + auto dynamic_offset_for_padded_concat = op::Subtract( + op::Constant(), op::Subtract(op::Multiply(op::Reshape(), op::Constant()), + start_input_element)); + auto pre_masking = + AllOf(op::Shape("f32[128,5,7,512]"), + op::DynamicSlice( + AllOf(op::Shape("f32[128,6,7,512]"), + op::Pad(op::Concatenate(left_halo, lhs), op::Constant())), + op::Constant(), dynamic_offset_for_padded_concat, + op::Constant(), op::Constant())); + auto masked = op::Select( + op::Compare(op::Add(op::Iota(), op::Broadcast(start_input_element)), + 
op::Broadcast(op::Constant())), + pre_masking, op::Broadcast(op::Constant())); + auto dynamic_offset_on_output = op::Subtract( + start_window, op::Multiply(start_input_element, op::Constant())); + EXPECT_THAT(root, + AllOf(op::DynamicSlice(AllOf(op::Convolution(masked, rhs), + op::Shape("f32[128,8,14,512]")), + op::Constant(), dynamic_offset_on_output, + op::Constant(), op::Constant()), + op::Shape("f32[128,7,14,512]"))); + EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_low(), 1); + EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_high(), 0); +} + +TEST_F(SpmdPartitioningTest, SelectAndScatterNoOverlap) { + const char* const hlo_string = R"( +HloModule module + +ge { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT compare = pred[] compare(a, b), direction=GE +} + +sum { + c = f32[] parameter(0) + d = f32[] parameter(1) + ROOT add = f32[] add(c, d) +} + +ENTRY entry { + %param = f32[11,4]{1,0} parameter(0) + %param.copy = f32[11,4] copy(%param), + sharding={devices=[4,1]0,1,2,3} + constant = f32[4,2]{1,0} constant({{1,2},{3,4},{1,0},{2,8}}), + sharding={devices=[4,1]0,1,2,3} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT select-and-scatter = f32[11,4]{1,0} select-and-scatter(param.copy, + constant, constant.1), window={size=3x2 stride=3x2 pad=0_1x0_0}, + select=ge, scatter=sum, sharding={devices=[4,1]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto source = + AllOf(op::Shape("f32[1,2]{1,0}"), + op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant())); + auto masked_data = AllOf( + op::Shape("f32[3,4]{1,0}"), + op::Select( + op::Compare(op::Add(op::Iota(), op::Broadcast(op::Multiply( + op::Reshape(), op::Constant()))), + op::Broadcast(op::Constant())), + op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()), + op::Reshape(), op::Constant())), + op::Broadcast(op::Constant()))); + + EXPECT_THAT(root, + AllOf(op::SelectAndScatter(masked_data, source, op::Constant()), + op::Shape("f32[3,4]{1,0}"))); + EXPECT_EQ(root->window().dimensions(0).padding_low(), 0); + EXPECT_EQ(root->window().dimensions(0).padding_high(), 0); +} + +TEST_F(SpmdPartitioningTest, SelectAndScatterNoOverlapReshard) { + const char* const hlo_string = R"( +HloModule module + +ge { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT compare = pred[] compare(a, b), direction=GE +} + +sum { + c = f32[] parameter(0) + d = f32[] parameter(1) + ROOT add = f32[] add(c, d) +} + +ENTRY entry { + %param = f32[11,4]{1,0} parameter(0) + %param.copy = f32[11,4] copy(%param), + sharding={devices=[1,4]0,1,2,3} + constant = f32[4,2]{1,0} constant({{1,2},{3,4},{1,0},{2,8}}), + sharding={devices=[4,1]0,1,2,3} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT select-and-scatter = f32[11,4]{1,0} select-and-scatter(param.copy, + constant, constant.1), window={size=3x2 stride=3x2 pad=0_1x0_0}, + select=ge, scatter=sum, sharding={devices=[4,1]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto source = + AllOf(op::Shape("f32[1,2]{1,0}"), + op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant())); + auto operand = AllOf(op::Copy(op::DynamicSlice( + op::Parameter(0), op::Constant(), op::Reshape())), + op::Shape("f32[11,1]")); + auto 
reshard_operand = op::Reshape(op::Transpose( + op::AllToAll(op::Reshape(op::Pad(operand, op::Constant()))))); + auto masked_data = AllOf( + op::Shape("f32[3,4]{1,0}"), + op::Select( + op::Compare(op::Add(op::Iota(), op::Broadcast(op::Multiply( + op::Reshape(), op::Constant()))), + op::Broadcast(op::Constant())), + reshard_operand, op::Broadcast(op::Constant()))); + + EXPECT_THAT(root, + AllOf(op::SelectAndScatter(masked_data, source, op::Constant()), + op::Shape("f32[3,4]{1,0}"))); + EXPECT_EQ(root->window().dimensions(0).padding_low(), 0); + EXPECT_EQ(root->window().dimensions(0).padding_high(), 0); +} + +TEST_F(SpmdPartitioningTest, SelectAndScatterWithOverlap) { + const char* const hlo_string = R"( +HloModule module + +ge { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT compare = pred[] compare(a, b), direction=GE +} + +sum { + c = f32[] parameter(0) + d = f32[] parameter(1) + ROOT add = f32[] add(c, d) +} + +ENTRY entry { + %param = f32[11,4]{1,0} parameter(0) + %param.copy = f32[11,4] copy(%param), + sharding={devices=[4,1]0,1,2,3} + constant = f32[6,2]{1,0} constant({{1,2},{3,4},{1,0},{2,8},{6,6},{1,9}}), + sharding={devices=[4,1]0,1,2,3} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT select-and-scatter = f32[11,4]{1,0} select-and-scatter(param.copy, + constant, constant.1), window={size=3x2 stride=2x2 pad=1_1x0_0}, + select=ge, scatter=sum, sharding={devices=[4,1]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + + auto source_shard = + AllOf(op::Shape("f32[2,2]{1,0}"), + op::DynamicSlice(op::Pad(), op::Reshape(), op::Constant())); + // Max halo size is the same as the shard size, so slice is not needed. 
+ auto source_left_halo = op::CollectivePermute(source_shard); + auto required_source_shard_start = + op::Divide(op::Multiply(op::Reshape(), op::Constant()), op::Constant()); + auto source_with_halo = op::DynamicSlice( + AllOf(op::Shape("f32[5,2]{1,0}"), + op::Pad(op::Concatenate(source_left_halo, source_shard), + op::Constant())), + op::Subtract(op::Constant(), + op::Subtract(op::Multiply(op::Reshape(), op::Constant()), + required_source_shard_start)), + op::Constant()); + auto masked_source_with_halo = AllOf( + AllOf(op::Shape("f32[3,2]{1,0}")), + op::Select( + op::Compare( + op::Add(op::Iota(), op::Broadcast(required_source_shard_start)), + op::Broadcast(op::Constant())), + source_with_halo, op::Broadcast(op::Constant()))); + + auto data_shard = + AllOf(op::Shape("f32[3,4]{1,0}"), + op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()), + op::Reshape(), op::Constant()))); + auto data_left_halo = AllOf(op::Shape("f32[2,4]{1,0}"), + op::CollectivePermute(op::Slice(data_shard))); + auto data_right_halo = AllOf(op::Shape("f32[2,4]{1,0}"), + op::CollectivePermute(op::Slice(data_shard))); + auto required_data_start_on_padded = + op::Multiply(required_source_shard_start, op::Constant()); + auto left_halo_size = op::Subtract( + op::Add(op::Multiply(op::Reshape(), op::Constant()), op::Constant()), + required_data_start_on_padded); + auto data_with_halo = + AllOf(op::Shape("f32[7,4]{1,0}"), + op::DynamicSlice( + AllOf(op::Shape("f32[8,4]{1,0}"), + op::Pad(op::Concatenate(data_left_halo, data_shard, + data_right_halo), + op::Constant())), + op::Subtract(op::Constant(), left_halo_size), op::Constant())); + auto index_on_padded = + op::Add(op::Iota(), op::Broadcast(required_data_start_on_padded)); + auto masked_data_with_halo = op::Select( + op::And(op::Compare(index_on_padded, op::Broadcast(op::Constant())), + op::Compare(index_on_padded, op::Broadcast(op::Constant()))), + data_with_halo, op::Broadcast(op::Constant())); + + EXPECT_THAT( + root, AllOf(op::DynamicSlice(op::SelectAndScatter(masked_data_with_halo, + masked_source_with_halo, + op::Constant()), + left_halo_size, op::Constant()), + op::Shape("f32[3,4]{1,0}"))); + EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_low(), 0); + EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_high(), 0); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiled) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,56,56,64] parameter(0) + %lhs.copy = f32[128,56,56,64] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,56,56,256] parameter(1) + %rhs.copy = f32[128,56,56,256] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[1,1,64,256] convolution(%lhs.copy, %rhs.copy), + window={size=56x56}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,28,56,64]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,28,56,256]")); + + EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(lhs, rhs)), + op::Shape("f32[1,1,64,256]"))); +} + +TEST_F(SpmdPartitioningTest, DotLhsTiledRhsTiledWithReshard) { + const char* const 
hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,56,56,64] parameter(0) + %lhs.copy = f32[128,56,56,64] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,56,56,256] parameter(1) + %rhs.copy = f32[128,56,56,256] copy(%rhs), sharding={devices=[2,1,1,1]0,1} + ROOT %conv = f32[1,1,64,256] convolution(%lhs.copy, %rhs.copy), + window={size=56x56}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,28,56,64]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(), + op::Constant(), op::Constant())), + op::Shape("f32[64,56,56,256]")); + auto all_to_all = + AllOf(op::AllToAll(op::Reshape(lhs)), op::Shape("f32[2,64,28,56,64]")); + auto reshard = AllOf(op::Reshape(op::Transpose(all_to_all))); + + EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(reshard, rhs)), + op::Shape("f32[1,1,64,256]"))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWithReshard) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,56,56,512] parameter(0) + %lhs.copy = f32[128,56,56,512] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,28,28,64] parameter(1) + %rhs.copy = f32[128,28,28,64] copy(%rhs), sharding={devices=[2,1,1,1]0,1} + ROOT %conv = f32[1,1,512,64] convolution(%lhs.copy, %rhs.copy), + window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2}, + dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,28,56,512]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(), + op::Constant(), op::Constant())), + op::Shape("f32[64,28,28,64]")); + auto all_to_all = + AllOf(op::AllToAll(op::Reshape(rhs)), op::Shape("f32[64,2,14,28,64]")); + auto reshard = op::Reshape(op::Transpose(all_to_all)); + + EXPECT_THAT(root, + AllOf(op::AllReduce(op::Convolution(op::Slice(lhs), reshard)), + op::Shape("f32[1,1,512,64]"))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWithPadding) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,28,28,128] parameter(0) + %lhs.copy = f32[32,28,28,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[32,28,28,64] parameter(1) + %rhs.copy = f32[32,28,28,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[3,3,128,64] convolution(%lhs.copy, %rhs.copy), + window={size=28x28 pad=1_1x1_1}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + PartitionComputation(hlo_string, /*num_devices=*/2, + /*conv_halo_exchange_always_on_lhs=*/false)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[32,14,28,128]")); + auto rhs = AllOf( + 
op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[32,14,28,64]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(rhs)), + op::Shape("f32[32,1,28,64]")); + auto right_halo = AllOf(op::CollectivePermute(op::Slice(rhs)), + op::Shape("f32[32,1,28,64]")); + EXPECT_THAT(root, + AllOf(op::AllReduce(op::Convolution( + lhs, AllOf(op::Concatenate(left_halo, rhs, right_halo), + op::Shape("f32[32,16,28,64]")))), + op::Shape("f32[3,3,128,64]"))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWindowDilate) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,224,224,3] parameter(0) + %lhs.copy = f32[128,224,224,3] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,112,112,64] parameter(1) + %rhs.copy = f32[128,112,112,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[7,7,3,64] convolution(%lhs.copy, %rhs.copy), + window={size=112x112 pad=3_2x3_2 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + PartitionComputation(hlo_string, /*num_devices=*/2, + /*conv_halo_exchange_always_on_lhs=*/false)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,112,224,3]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,56,112,64]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(rhs)), + op::Shape("f32[128,2,112,64]")); + auto right_halo = AllOf(op::CollectivePermute(op::Slice(rhs)), + op::Shape("f32[128,2,112,64]")); + EXPECT_THAT(root, + AllOf(op::AllReduce(op::Convolution( + lhs, AllOf(op::Concatenate(left_halo, rhs, right_halo), + op::Shape("f32[128,60,112,64]")))), + op::Shape("f32[7,7,3,64]"))); +} + +TEST_F(SpmdPartitioningTest, + ConvolutionLhsTiledRhsTiledWindowDilateNegativeRhsPadding) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,56,56,256] parameter(0) + %lhs.copy = f32[128,56,56,256] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,28,28,512] parameter(1) + %rhs.copy = f32[128,28,28,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[1,1,256,512] convolution(%lhs.copy, %rhs.copy), + window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + PartitionComputation(hlo_string, /*num_devices=*/2, + /*conv_halo_exchange_always_on_lhs=*/false)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,28,56,256]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,14,28,512]")); + + EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(lhs, rhs)), + op::Shape("f32[1,1,256,512]"))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWindowDilateUneven) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,14,14,512] parameter(0) + %lhs.copy = f32[128,14,14,512] 
copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,7,7,512] parameter(1) + %rhs.copy = f32[128,7,7,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[3,3,512,512] convolution(%lhs.copy, %rhs.copy), + window={size=7x7 pad=1_0x1_0 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + PartitionComputation(hlo_string, /*num_devices=*/2, + /*conv_halo_exchange_always_on_lhs=*/false)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,7,14,512]")); + auto rhs = AllOf( + op::Select(op::Compare(), + op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Constant())), + op::Broadcast()), + op::Shape("f32[128,4,7,512]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(rhs)), + op::Shape("f32[128,1,7,512]")); + EXPECT_THAT(root, + AllOf(op::AllReduce(op::Convolution( + AllOf(op::DynamicSlice(op::Pad(lhs, op::Constant()), + op::Constant(), op::Subtract(), + op::Constant(), op::Constant()), + op::Shape("f32[128,10,14,512]")), + AllOf(op::Concatenate(left_halo, rhs), + op::Shape("f32[128,5,7,512]")))), + op::Shape("f32[3,3,512,512]"))); +} + +TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWithPadding_HaloOnLhs) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,28,28,128] parameter(0) + %lhs.copy = f32[32,28,28,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[32,28,28,64] parameter(1) + %rhs.copy = f32[32,28,28,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[3,3,128,64] convolution(%lhs.copy, %rhs.copy), + window={size=28x28 pad=1_1x1_1}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[32,14,28,128]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[32,14,28,64]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[32,1,28,128]")); + auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[32,1,28,128]")); + EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution( + AllOf(op::Concatenate(left_halo, lhs, right_halo), + op::Shape("f32[32,16,28,128]")), + rhs)), + op::Shape("f32[3,3,128,64]"))); +} + +TEST_F(SpmdPartitioningTest, + ConvolutionLhsTiledRhsTiledWindowDilate_HaloOnLhs) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,224,224,3] parameter(0) + %lhs.copy = f32[128,224,224,3] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,112,112,64] parameter(1) + %rhs.copy = f32[128,112,112,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[7,7,3,64] convolution(%lhs.copy, %rhs.copy), + window={size=112x112 pad=3_2x3_2 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) 
<< module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,112,224,3]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,56,112,64]")); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[128,3,224,3]")); + auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[128,2,224,3]")); + EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution( + AllOf(op::Concatenate(left_halo, lhs, right_halo), + op::Shape("f32[128,117,224,3]")), + rhs)), + op::Shape("f32[7,7,3,64]"))); +} + +TEST_F(SpmdPartitioningTest, + ConvolutionLhsTiledRhsTiledWindowDilateNegativeRhsPadding_HaloOnLhs) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,56,56,256] parameter(0) + %lhs.copy = f32[128,56,56,256] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,28,28,512] parameter(1) + %rhs.copy = f32[128,28,28,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[1,1,256,512] convolution(%lhs.copy, %rhs.copy), + window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,28,56,256]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,14,28,512]")); + + EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(op::Slice(lhs), rhs)), + op::Shape("f32[1,1,256,512]"))); +} + +TEST_F(SpmdPartitioningTest, + ConvolutionLhsTiledRhsTiledWindowDilateUneven_HaloOnLhs) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,14,14,512] parameter(0) + %lhs.copy = f32[128,14,14,512] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[128,7,7,512] parameter(1) + %rhs.copy = f32[128,7,7,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %conv = f32[3,3,512,512] convolution(%lhs.copy, %rhs.copy), + window={size=7x7 pad=1_0x1_0 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[128,7,14,512]")); + auto rhs = AllOf( + op::Select(op::Compare(), + op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Constant())), + op::Broadcast()), + op::Shape("f32[128,4,7,512]")); + + auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[128,1,14,512]")); + EXPECT_THAT( + root, AllOf(op::AllReduce(op::Convolution( + AllOf(op::DynamicSlice( + AllOf(op::Pad(op::Concatenate(lhs, right_halo), + op::Constant()), + op::Shape("f32[128,10,14,512]")), + op::Constant(), op::Reshape(), 
op::Constant(), + op::Constant()), + op::Shape("f32[128,9,14,512]")), + rhs)), + op::Shape("f32[3,3,512,512]"))); +} + +TEST_F(SpmdPartitioningTest, ConcatenateAlongNonPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[14,257] parameter(0) + %param0.copy = f32[14,257] copy(%param0), sharding={devices=[2,1]0,1} + %param1 = f32[14,116] parameter(1) + %param1.copy = f32[14,116] copy(%param1), sharding={devices=[2,1]0,1} + ROOT %concatenate = f32[14,373] concatenate(%param0.copy, %param1.copy), + dimensions={1}, sharding={devices=[2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), + op::Constant())), + op::Shape("f32[7,257]")); + auto param1 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), + op::Constant())), + op::Shape("f32[7,116]")); + EXPECT_THAT(root, + AllOf(op::Concatenate(param0, param1), op::Shape("f32[7,373]"))); +} + +TEST_F(SpmdPartitioningTest, ConcatenateAlongPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[14,257] parameter(0) + %param0.copy = f32[14,257] copy(%param0), sharding={devices=[1,2]0,1} + %param1 = f32[14,116] parameter(1) + %param1.copy = f32[14,116] copy(%param1), sharding={devices=[1,2]0,1} + ROOT %concatenate = f32[14,373] concatenate(%param0.copy, %param1.copy), + dimensions={1}, sharding={devices=[1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = + AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()), + op::Constant(), op::Reshape())), + op::Shape("f32[14,129]")); + auto param1 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), + op::Reshape())), + op::Shape("f32[14,58]")); + EXPECT_THAT(root, AllOf(op::DynamicSlice( + AllOf(op::AllReduce(op::DynamicUpdateSlice( + op::DynamicUpdateSlice( + op::Broadcast(), param0, + op::Constant(), op::Multiply()), + param1, op::Constant(), op::Add())), + op::Shape("f32[14,374]")), + op::Constant(), op::Multiply()), + op::Shape("f32[14,187]"))); +} + +TEST_F(SpmdPartitioningTest, PadAlongNonPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[128,14,257] parameter(0) + %param0.copy = f32[128,14,257] copy(%param0), sharding={devices=[1,1,2]0,1} + %const = f32[] constant(0) + ROOT %pad = f32[128,17,257] pad(%param0.copy, %const), padding=0_0x1_2x0_0, + sharding={devices=[1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf( + op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()), + op::Constant(), op::Constant(), op::Reshape())), + op::Shape("f32[128,14,129]")); + EXPECT_THAT(root, AllOf(op::Pad(param0, op::Constant()), + op::Shape("f32[128,17,129]"))); +} + +TEST_F(SpmdPartitioningTest, SliceAlongNonPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[128,14,257] parameter(0) + %param0.copy = f32[128,14,257] copy(%param0), 
sharding={devices=[1,1,2]0,1} + ROOT %slice = f32[128,11,257] slice(%param0.copy), + slice={[0:128:1], [2:13:1], [0:257:1]}, sharding={devices=[1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf( + op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()), + op::Constant(), op::Constant(), op::Reshape())), + op::Shape("f32[128,14,129]")); + EXPECT_THAT(root, AllOf(op::Slice(param0), op::Shape("f32[128,11,129]"))); +} + +TEST_F(SpmdPartitioningTest, SliceAlongPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[128,14,257] parameter(0) + %param0.copy = f32[128,14,257] copy(%param0), sharding={devices=[1,1,2]0,1} + ROOT %slice = f32[63,14,251] slice(%param0.copy), + slice={[2:128:2], [0:14:1], [5:256:1]}, sharding={devices=[1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf( + op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()), + op::Constant(), op::Constant(), op::Reshape())), + op::Shape("f32[128,14,129]")); + EXPECT_THAT( + root, + AllOf(op::Slice(AllOf( + op::DynamicSlice( + AllOf(op::Concatenate( + param0, + AllOf(op::CollectivePermute(op::Slice(param0)), + op::Shape("f32[128,14,2]"))), + op::Shape("f32[128,14,131]")), + op::Constant(), op::Constant(), op::Add()), + op::Shape("f32[128,14,126]"))), + op::Shape("f32[63,14,126]"))); +} + +TEST_F(SpmdPartitioningTest, SortAlongNonPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ge { + p.0.lhs.1247 = f32[]{:T(256)} parameter(0), sharding={replicated} + bitcast-convert = s32[]{:T(256)} bitcast-convert(p.0.lhs.1247), sharding={replicated} + constant = s32[]{:T(256)} constant(0), sharding={replicated} + compare = pred[]{:T(256)E(32)} compare(bitcast-convert, constant), direction=LT, sharding={replicated} + constant.1 = u32[]{:T(256)} constant(2147483647), sharding={replicated} + bitcast-convert.1 = u32[]{:T(256)} bitcast-convert(p.0.lhs.1247), sharding={replicated} + subtract = u32[]{:T(256)} subtract(constant.1, bitcast-convert.1), sharding={replicated} + bitcast-convert.2 = s32[]{:T(256)} bitcast-convert(subtract), sharding={replicated} + select = s32[]{:T(256)} select(compare, bitcast-convert.2, bitcast-convert), sharding={replicated} + p.0.rhs.1248 = f32[]{:T(256)} parameter(1), sharding={replicated} + bitcast-convert.3 = s32[]{:T(256)} bitcast-convert(p.0.rhs.1248), sharding={replicated} + compare.1 = pred[]{:T(256)E(32)} compare(bitcast-convert.3, constant), direction=LT, sharding={replicated} + bitcast-convert.4 = u32[]{:T(256)} bitcast-convert(p.0.rhs.1248), sharding={replicated} + subtract.1 = u32[]{:T(256)} subtract(constant.1, bitcast-convert.4), sharding={replicated} + bitcast-convert.5 = s32[]{:T(256)} bitcast-convert(subtract.1), sharding={replicated} + select.1 = s32[]{:T(256)} select(compare.1, bitcast-convert.5, bitcast-convert.3), sharding={replicated} + compare.2 = pred[]{:T(256)E(32)} compare(select, select.1), direction=GT, sharding={replicated} + compare.258 = pred[]{:T(256)E(32)} compare(select.1, select), direction=GT, sharding={replicated} + compare.259 = pred[]{:T(256)E(32)} compare(compare.2, compare.258), direction=EQ, sharding={replicated} + 
p.1.lhs.1249 = s32[]{:T(256)} parameter(2), sharding={replicated} + p.1.rhs.1250 = s32[]{:T(256)} parameter(3), sharding={replicated} + compare.260 = pred[]{:T(256)E(32)} compare(p.1.lhs.1249, p.1.rhs.1250), direction=LT, sharding={replicated} + ROOT select.86 = pred[]{:T(256)E(32)} select(compare.259, compare.260, compare.2), sharding={replicated} +} + +ENTRY entry { + %param0 = f32[128,14,257] parameter(0) + %param0.copy = f32[128,14,257] copy(%param0), sharding={devices=[1,2,1]0,1} + %param1 = s32[128,14,257] parameter(1) + %param1.copy = s32[128,14,257] copy(%param1), sharding={devices=[1,2,1]0,1} + ROOT %sort.6 = (f32[128,14,257]{2,1,0:T(8,128)}, s32[128,14,257]{2,1,0:T(8,128)}) + sort(%param0.copy, %param1.copy), dimensions={2}, is_stable=true, + to_apply=%ge, sharding={{devices=[1,2,1]0,1},{devices=[1,2,1]0,1}} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = + AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), + op::Reshape(), op::Constant())), + op::Shape("f32[128,7,257]")); + auto param1 = + AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(), + op::Reshape(), op::Constant())), + op::Shape("s32[128,7,257]")); + EXPECT_THAT(root, AllOf(op::Sort(param0, param1), + op::Shape("(f32[128,7,257], s32[128,7,257])"))); +} + +TEST_F(SpmdPartitioningTest, PartitionCustomCall) { + const char* const hlo_string = R"( +HloModule cluster_2013453984438090939__.47 + +ENTRY %cluster_2013453984438090939__.47 + (arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) { + %arg_tuple.1 = bf16[2,209664] parameter(0) + %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1} + %custom-call = (bf16[2,2000]{1,0}, s32[2,2000]{1,0}) + custom-call(bf16[2,209664]{1,0} %copy.arg_tuple.1), custom_call_target="TopK" + %get-tuple-element = bf16[2,2000]{1,0} + get-tuple-element((bf16[2,2000]{1,0}, s32[2,2000]{1,0}) %custom-call), + index=0, sharding={replicated} + %get-tuple-element.1 = s32[2,2000]{1,0} get-tuple-element((bf16[2,2000]{1,0}, + s32[2,2000]{1,0}) %custom-call), index=1, sharding={replicated} + ROOT %tuple.46 = (bf16[2,2000]{1,0}, s32[2,2000]{1,0}) + tuple(bf16[2,2000]{1,0} %get-tuple-element, s32[2,2000]{1,0} + %get-tuple-element.1), sharding={{replicated}, {replicated}}, + metadata={op_name="XLA_Retvals"} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto custom_call = FindInstruction(module.get(), "custom-call.1"); + EXPECT_EQ(custom_call->operand(0)->shape().dimensions(1), 104832); + auto sort = FindInstruction(module.get(), "sort"); + EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 4000); + EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 4000); +} + +TEST_F(SpmdPartitioningTest, ShardableTranspose) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[16,38,38,4] parameter(0) + %param0.copy = f32[16,38,38,4] copy(%param0), sharding={devices=[1,2,1,1]0,1} + ROOT %transpose = f32[16,4,38,38] transpose(%param0.copy), + dimensions={0,3,1,2}, sharding={devices=[1,1,2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), 
op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[16,19,38,4]")); + EXPECT_THAT(root, AllOf(op::Transpose(param0), op::Shape("f32[16,4,19,38]"))); +} + +TEST_F(SpmdPartitioningTest, MultiDimensionShardedTranspose) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[16,38,38,4] parameter(0) + %param0.copy = f32[16,38,38,4] copy(%param0), + sharding={devices=[4,2,1,1]0,1,2,3,4,5,6,7} + ROOT %transpose = f32[38,4,16,38] transpose(%param0.copy), + dimensions={1,3,0,2}, sharding={devices=[2,1,4,1]0,2,4,6,1,3,5,7} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[4,19,38,4]")); + EXPECT_THAT(root, AllOf(op::Transpose(param0), op::Shape("f32[19,4,4,38]"))); +} + +TEST_F(SpmdPartitioningTest, NonShardableTranspose) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[16,38,38,4] parameter(0) + %param0.copy = f32[16,38,38,4] copy(%param0), sharding={devices=[1,2,1,1]0,1} + ROOT %transpose = f32[16,4,38,38] transpose(%param0.copy), + dimensions={0,3,1,2}, sharding={devices=[1,2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto resahrd = AllOf(op::Reshape(op::Transpose(op::Reshape(op::AllToAll()))), + op::Shape("f32[16,38,38,2]")); + EXPECT_THAT(root, AllOf(op::Transpose(), op::Shape("f32[16,2,38,38]"))); +} + +TEST_F(SpmdPartitioningTest, ShardableReshape) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[38,38,324] parameter(0) + %param0.copy = f32[38,38,324] copy(%param0), sharding={devices=[2,1,1]0,1} + ROOT %reshape = f32[38,38,4,81] reshape(%param0.copy), + sharding={devices=[2,1,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = + AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[19,38,324]")); + EXPECT_THAT(root, AllOf(op::Reshape(param0), op::Shape("f32[19,38,4,81]"))); +} + +TEST_F(SpmdPartitioningTest, NonShardableReshape) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[38,38,324] parameter(0) + %param0.copy = f32[38,38,324] copy(%param0), sharding={devices=[1,1,2]0,1} + ROOT %transpose = f32[38,38,4,81] reshape(%param0.copy), + sharding={devices=[1,1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::DynamicSlice( + AllOf(op::Pad( + AllOf(op::Reshape(AllOf(op::AllReduce(), + op::Shape("f32[38,38,324]"))), + op::Shape("f32[38,38,4,81]")), + op::Constant()), + op::Shape("f32[38,38,4,82]")), + op::Constant(), op::Constant(), op::Constant(), op::Reshape()), + op::Shape("f32[38,38,4,41]"))); +} + +TEST_F(SpmdPartitioningTest, ReshapeMergeDimsWithHaloExchange) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + 
%input = s32[2,3,7,10] parameter(0), sharding={devices=[1,1,2,1]0,1} + ROOT %reshape = s32[3,2,1,14,5] reshape(%input), + sharding={devices=[1,1,1,2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto reshape = + AllOf(op::Reshape(op::Parameter(0)), op::Shape("s32[3,2,1,8,5]")); + auto halo = op::CollectivePermute(op::Slice(reshape)); + auto exchanged = + op::DynamicSlice(op::Concatenate(halo, reshape), _, _, _, _, _); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(exchanged, op::Shape("s32[3,2,1,7,5]"))); +} + +// Produces an invalid module after transformation. +TEST_F(SpmdPartitioningTest, InceptionV3_4_way_ReduceWindowDilated) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + %param0 = f32[128,5,5,768] parameter(0) + %param0.copy = f32[128,5,5,768] copy(%param0), + sharding={devices=[1,4,1,1]0,1,2,3} + %constant.1 = f32[] constant(0), sharding={replicated} + ROOT %rw = f32[128,17,17,768] reduce-window(%param0.copy, %constant.1), + window={size=1x5x5x1 pad=0_0x4_4x4_4x0_0 lhs_dilate=1x3x3x1}, + to_apply=sum, sharding={devices=[1,4,1,1]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto input_shard = op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(0), op::Constant()), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())); + auto id_mul4_add1 = + op::Add(op::Multiply(op::Reshape(), op::Constant()), op::Constant()); + auto id_mul5 = op::Multiply(op::Reshape(), op::Constant()); + auto id_mul5_add1_div3 = + op::Divide(op::Add(id_mul5, op::Constant()), op::Constant()); + auto before_masking = AllOf( + op::Shape("f32[128,3,5,768]"), + op::DynamicSlice( + AllOf( + op::Shape("f32[128,4,5,768]"), + op::Concatenate(op::CollectivePermute(input_shard), input_shard)), + op::Constant(), + op::Subtract(op::Constant(), + op::Subtract(id_mul4_add1, id_mul5_add1_div3)), + op::Constant(), op::Constant())); + auto masked = op::Select( + op::And(op::Compare(op::Add(op::Iota(), op::Broadcast(id_mul5_add1_div3)), + op::Broadcast(op::Constant())), + op::Compare(op::Add(op::Iota(), op::Broadcast(id_mul5_add1_div3)), + op::Broadcast(op::Constant()))), + before_masking, op::Broadcast(op::Constant())); + auto rw = AllOf(op::Shape("f32[128,7,17,768]"), + op::ReduceWindow(masked, op::Constant())); + auto final_slice_index = op::Subtract( + id_mul5, + op::Add(op::Multiply(id_mul5_add1_div3, op::Constant()), op::Constant())); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + AllOf(op::Shape("f32[128,5,17,768]"), + op::DynamicSlice(rw, op::Constant(), final_slice_index, + op::Constant(), op::Constant()))); +} + +TEST_F(SpmdPartitioningTest, TiledToTiledReduce) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + %param0 = f32[4,32,32,128] parameter(0) + %param0.copy = f32[4,32,32,128] copy(%param0), + sharding={devices=[1,1,1,2]0,1} + %constant.1 = f32[] constant(0), sharding={replicated} + %reduce = f32[128] reduce(%param0.copy, %constant.1), dimensions={0,1,2}, + to_apply=%sum, sharding={devices=[2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, 
/*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[4,32,32,64]")); + + EXPECT_THAT(root, + AllOf(op::Reduce(param0, op::Constant()), op::Shape("f32[64]"))); +} + +TEST_F(SpmdPartitioningTest, TiledToTiledTupleReduce) { + const char* const hlo_string = R"( +HloModule module + +%minmax_func { + %lhs_value = f32[] parameter(0) + %rhs_value = f32[] parameter(2) + %compare.2 = pred[] compare(%lhs_value, %rhs_value), direction=GT + %select.4 = f32[] select(%compare.2, %lhs_value, %rhs_value) + %lhs_index = s32[] parameter(1) + %rhs_index = s32[] parameter(3) + %select.5 = s32[] select(%compare.2, %lhs_index, %rhs_index) + ROOT %tuple.2 = (f32[], s32[]) tuple(%select.4, %select.5) +} + +ENTRY %main { + %param0 = f32[28,10] parameter(0), sharding={devices=[2,1]0,1} + %param1 = s32[28,10] parameter(1), sharding={devices=[2,1]0,1} + %init0 = f32[] parameter(2) + %init1 = s32[] parameter(3) + ROOT %reduce = (f32[28], s32[28]) reduce(%param0, %param1, %init0, %init1), + dimensions={1}, to_apply=%minmax_func, + sharding={{devices=[2]0,1}, {devices=[2]0,1}} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Reduce(op::Parameter(0), op::Parameter(1), + op::Parameter(2), op::Parameter(3)), + op::Shape("(f32[14], s32[14])"))); +} + +TEST_F(SpmdPartitioningTest, TiledToTiledReduceOutputReshard) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + %param0 = f32[4,32,32,128] parameter(0) + %param0.copy = f32[4,32,32,128] copy(%param0), + sharding={devices=[1,2,1,1]0,1} + %constant.1 = f32[] constant(0), sharding={replicated} + %reduce = f32[128] reduce(%param0.copy, %constant.1), dimensions={0,1,2}, + to_apply=%sum, sharding={devices=[2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[4,16,32,128]")); + + EXPECT_THAT(root, + AllOf(op::DynamicSlice( + AllOf(op::AllReduce(op::Reduce(param0, op::Constant())), + op::Shape("f32[128]")), + op::Reshape()), + op::Shape("f32[64]"))); +} + +TEST_F(SpmdPartitioningTest, IotaAlongNonTileDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + ROOT %iota = s32[16,80,91] iota(), iota_dimension=1, + sharding={devices=[1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Iota(), op::Shape("s32[16,80,46]"))); +} + +TEST_F(SpmdPartitioningTest, IotaAlongTileDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + ROOT %iota = s32[16,80,91] iota(), iota_dimension=2, + sharding={devices=[1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << 
module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Add(op::Iota(), op::Broadcast()), + op::Shape("s32[16,80,46]"))); +} + +TEST_F(SpmdPartitioningTest, U32IotaAlongTileDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + ROOT %iota = u32[16,80,91] iota(), iota_dimension=2, + sharding={devices=[1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Add(op::Iota(), op::Broadcast()), + op::Shape("u32[16,80,46]"))); +} + +TEST_F(SpmdPartitioningTest, Conditional) { + const char* const hlo_string = R"( +HloModule module + +Negate { + x = f32[4,5] parameter(0), sharding={replicated} + ROOT negate = f32[4,5] negate(x), sharding={replicated} +} + +Identity { + y = f32[4,5] parameter(0), sharding={devices=[2,1]0,1} + ROOT copy = f32[4,5] copy(y), sharding={devices=[2,1]0,1} +} + +ENTRY entry { + %param.0 = pred[] parameter(0) + %param.0.copy = pred[] copy(%param.0), sharding={maximal device=0} + %param.1 = f32[4,5] parameter(1) + %param.1.copy = f32[4,5] copy(%param.1), sharding={replicated} + %param.2 = f32[4,5] parameter(2) + %param.2.copy = f32[4,5] copy(%param.2), sharding={devices=[2,1]0,1} + ROOT cond = f32[4,5] conditional(%param.0.copy, %param.1.copy, %param.2.copy), + true_computation=Negate, false_computation=Identity, + sharding={devices=[2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto param0 = AllOf(op::Copy(op::Copy(op::Parameter()), op::Shape("pred[]"))); + auto param1 = AllOf(op::Copy(op::Parameter()), op::Shape("f32[4,5]")); + auto param2 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), + op::Constant())), + op::Shape("f32[2,5]")); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Conditional(op::AllReduce(), param1, param2), + op::Shape("f32[2,5]"))); + + auto then_branch_root = root->branch_computation(0)->root_instruction(); + EXPECT_THAT(then_branch_root, + AllOf(op::DynamicSlice(op::Negate(op::Parameter()), op::Reshape(), + op::Constant()), + op::Shape("f32[2,5]"))); + + auto else_branch_root = root->branch_computation(1)->root_instruction(); + EXPECT_THAT(else_branch_root, + AllOf(op::Copy(op::Parameter()), op::Shape("f32[2,5]"))); +} + +TEST_F(SpmdPartitioningTest, SelectAndScatter_RetinaNet) { + const char* const hlo_string = R"( +HloModule module + +ge { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT compare = pred[] compare(a, b), direction=GE +} + +sum { + c = f32[] parameter(0) + d = f32[] parameter(1) + ROOT add = f32[] add(c, d) +} + +ENTRY entry { + %param.0 = f32[32,128,384,64] parameter(0) + %param.0.copy = f32[32,128,384,64] copy(%param.0), + sharding={devices=[1,8,1,1]0,1,2,3,4,5,6,7} + %param.1 = f32[32,64,192,64] parameter(1) + %param.1.copy = f32[32,64,192,64] copy(%param.1), + sharding={devices=[1,8,1,1]0,1,2,3,4,5,6,7} + constant.1 = f32[] constant(0), sharding={replicated} + ROOT select-and-scatter = f32[32,128,384,64] select-and-scatter(param.0.copy, + %param.1.copy, constant.1), window={size=1x1x1x1 stride=1x2x2x1}, + select=ge, scatter=sum, sharding={devices=[1,8,1,1]0,1,2,3,4,5,6,7} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << 
module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto source = AllOf( + op::Shape("f32[32,8,192,64]"), + op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(), op::Reshape(), + op::Constant(), op::Constant()))); + auto data = AllOf( + op::Shape("f32[32,16,384,64]"), + op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(), + op::Constant(), op::Constant()))); + + EXPECT_THAT(root, op::SelectAndScatter(data, source, op::Constant())); + EXPECT_EQ(root->window().dimensions(0).padding_low(), 0); + EXPECT_EQ(root->window().dimensions(0).padding_high(), 0); +} + +TEST_F(SpmdPartitioningTest, TiledDot) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,64] parameter(0) + %lhs.copy = f32[128,64] copy(%lhs), sharding={devices=[1,2]0,1} + %rhs = f32[64,256] parameter(1) + %rhs.copy = f32[64,256] copy(%rhs), sharding={devices=[2,1]0,1} + ROOT %conv = f32[128,256] convolution(%lhs.copy, %rhs.copy), + dim_labels=bf_io->bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + PartitionComputation(hlo_string, /*num_devices=*/2, + /*conv_halo_exchange_always_on_lhs=*/false)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), + op::Reshape())), + op::Shape("f32[128,32]")); + auto rhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), + op::Constant())), + op::Shape("f32[32,256]")); + EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(lhs, rhs)), + op::Shape("f32[128,256]"))); +} + +TEST_F(SpmdPartitioningTest, TiledDotOutputTiled) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,64] parameter(0) + %lhs.copy = f32[128,64] copy(%lhs), sharding={devices=[1,2]0,1} + %rhs = f32[64,256] parameter(1) + %rhs.copy = f32[64,256] copy(%rhs), sharding={devices=[2,1]0,1} + ROOT %conv = f32[128,256] convolution(%lhs.copy, %rhs.copy), + dim_labels=bf_io->bf, sharding={devices=[1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), + op::Reshape())), + op::Shape("f32[128,32]")); + auto rhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), + op::Constant())), + op::Shape("f32[32,256]")); + EXPECT_THAT(root, AllOf(op::DynamicSlice( + AllOf(op::AllReduce(op::Convolution(lhs, rhs)), + op::Shape("f32[128,256]")), + op::Constant(), op::Reshape()), + op::Shape("f32[128,128]"))); +} + +TEST_F(SpmdPartitioningTest, BatchPartitionedConvolution) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[128,256,256] parameter(0) + %lhs.copy = f32[128,256,256] copy(%lhs), sharding={devices=[1,2,1]0,1} + %rhs = f32[256,8,1] parameter(1) + %rhs.copy = f32[256,8,1] copy(%rhs), sharding={replicated} + ROOT %conv = f32[128,256,8] convolution(%lhs.copy, %rhs.copy), + window={size=1}, dim_labels=0bf_io0->0bf, sharding={devices=[1,2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), + op::Reshape(), op::Constant())), + 
op::Shape("f32[128,128,256]")); + auto rhs = AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[256,8,1]")); + EXPECT_THAT(root, + AllOf(op::Convolution(lhs, rhs), op::Shape("f32[128,128,8]"))); +} + +TEST_F(SpmdPartitioningTest, DotOutputFeaturePartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[24,64] parameter(0) + %lhs.copy = f32[24,64] copy(%lhs), sharding={replicated} + %rhs = f32[39296,64] parameter(1) + %rhs.copy = f32[39296,64] copy(%rhs), sharding={devices=[2,1]0,1} + ROOT %dot = f32[24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={}, rhs_batch_dims={}, + lhs_contracting_dims={1}, rhs_contracting_dims={1}, + sharding={devices=[1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[24,64]")); + auto rhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(), + op::Constant())), + op::Shape("f32[19648,64]")); + EXPECT_THAT(root, AllOf(op::Dot(lhs, rhs), op::Shape("f32[24,19648]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumBatchPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64] parameter(0) + %lhs.copy = f32[32,24,64] copy(%lhs), sharding={devices=[2,1,1]0,1} + %rhs = f32[32,39296,64] parameter(1) + %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={devices=[2,1,1]0,1} + ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[16,24,64]")); + auto rhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[16,39296,64]")); + EXPECT_THAT(root, AllOf(op::Dot(lhs, rhs), op::Shape("f32[16,24,39296]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumLHSandOutputBatchPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64] parameter(0) + %lhs.copy = f32[32,24,64] copy(%lhs), sharding={devices=[2,1,1]0,1} + %rhs = f32[32,39296,64] parameter(1) + %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={replicated} + ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[16,24,64]")); + auto rhs = AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64]")); + EXPECT_THAT(root, AllOf(op::Dot(lhs, op::DynamicSlice(rhs, op::Reshape(), + op::Constant(), + op::Constant())), + op::Shape("f32[16,24,39296]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumRHSandOutputBatchPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + 
%lhs = f32[32,24,64] parameter(0) + %lhs.copy = f32[32,24,64] copy(%lhs), sharding={devices=[1,2,1]0,1} + %rhs = f32[32,39296,64] parameter(1) + %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={devices=[2,1,1]0,1} + ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), + op::Reshape(), op::Constant())), + op::Shape("f32[32,12,64]")); + auto rhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[16,39296,64]")); + auto lhs_reshard = op::Reshape(op::Transpose(op::AllToAll(op::Reshape(lhs)))); + EXPECT_THAT(root, + AllOf(op::Dot(lhs_reshard, rhs), op::Shape("f32[16,24,39296]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumOutputBatchPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64] parameter(0) + %lhs.copy = f32[32,24,64] copy(%lhs), sharding={replicated} + %rhs = f32[32,39296,64] parameter(1) + %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={replicated} + ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[2,1,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs_slice = + AllOf(op::DynamicSlice(op::Copy(op::Parameter(0)), op::Reshape(), + op::Constant(), op::Constant()), + op::Shape("f32[16,24,64]")); + auto rhs_slice = + AllOf(op::DynamicSlice(op::Copy(op::Parameter(1)), op::Reshape(), + op::Constant(), op::Constant()), + op::Shape("f32[16,39296,64]")); + EXPECT_THAT(root, AllOf(op::Dot(lhs_slice, rhs_slice), + op::Shape("f32[16,24,39296]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumContractingDimsPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,1,2,2]0,1,2,3} + %rhs = f32[32,39296,64,128] parameter(1) + %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={devices=[1,1,2,2]0,1,2,3} + ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), + op::Constant(), op::Reshape(), op::Reshape())), + op::Shape("f32[32,24,32,64]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(), + op::Constant(), op::Reshape(), op::Reshape())), + op::Shape("f32[32,39296,32,64]")); + EXPECT_THAT(root, AllOf(op::AllReduce(op::Dot(lhs, rhs)), + op::Shape("f32[32,24,39296]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumLHSNonContractingDimsPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = 
f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,2]0,1,2,3} + %rhs = f32[32,39296,64] parameter(1) + %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={replicated} + ROOT %dot = f32[32,24,128,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[1,2,2,1]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(), + op::Constant(), op::Reshape())), + op::Shape("f32[32,12,64,64]")); + auto rhs = AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64]")); + EXPECT_THAT(root, AllOf(op::Dot(lhs, rhs), op::Shape("f32[32,12,64,39296]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumRHSNonContractingDimsPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64] parameter(0) + %lhs.copy = f32[32,24,64] copy(%lhs), sharding={replicated} + %rhs = f32[32,39296,64,128] parameter(1) + %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={devices=[1,2,1,2]0,1,2,3} + ROOT %dot = f32[32,24,39296,128] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[1,1,2,2]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[32,24,64]")); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(), op::Reshape(), + op::Constant(), op::Reshape())), + op::Shape("f32[32,19648,64,64]")); + EXPECT_THAT(root, AllOf(op::Dot(lhs, rhs), op::Shape("f32[32,24,19648,64]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumOutputLHSNonContractingDimPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={replicated} + %rhs = f32[32,39296,64,128] parameter(1) + %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={replicated} + ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={devices=[1,2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[32,24,64,128]")); + auto rhs = + AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64,128]")); + EXPECT_THAT( + root, + AllOf(op::Dot(AllOf(op::DynamicSlice(lhs, op::Constant(), op::Reshape(), + op::Constant(), op::Constant()), + op::Shape("f32[32,12,64,128]")), + rhs), + op::Shape("f32[32,12,39296]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumOutputRHSNonContractingDimPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={replicated} + %rhs = f32[32,39296,64,128] parameter(1) + %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={replicated} + ROOT 
%dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={devices=[1,1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[32,24,64,128]")); + auto rhs = + AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64,128]")); + EXPECT_THAT(root, + AllOf(op::Dot(lhs, AllOf(op::DynamicSlice( + rhs, op::Constant(), op::Reshape(), + op::Constant(), op::Constant()), + op::Shape("f32[32,19648,64,128]"))), + op::Shape("f32[32,24,19648]"))); +} + +TEST_F(SpmdPartitioningTest, EinsumRHSWindowedNonContracting) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[32,39295,64,128] parameter(1) + %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + ROOT %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={devices=[1,2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, + /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[32,12,64,128]")); + auto rhs = + AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(1), op::Constant()), + op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[32,19648,64,128]")); + EXPECT_THAT( + root, + AllOf(op::Slice(AllOf(op::GetTupleElement(op::While(op::Tuple( + lhs, rhs, op::Broadcast(), op::Constant()))), + op::Shape("f32[32,12,39296]"))), + op::Shape("f32[32,12,39295]"))); + auto while_loop = root->operand(0)->operand(0); + // Check loop condition. + EXPECT_THAT( + while_loop->while_condition()->root_instruction(), + op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant())); + + // Check loop body. + auto next_i = op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()); + auto window = op::Conditional(op::Compare(next_i, op::Constant()), + op::GetTupleElement(op::Parameter(0)), + op::GetTupleElement(op::Parameter(0))); + auto partial_output = op::Dot(op::GetTupleElement(op::Parameter(0)), + op::GetTupleElement(op::Parameter(0))); + EXPECT_THAT( + while_loop->while_body()->root_instruction(), + op::Tuple(op::GetTupleElement(op::Parameter(0)), window, + op::DynamicUpdateSlice(op::GetTupleElement(op::Parameter(0)), + partial_output, op::Constant(), + op::Constant(), op::Reshape()), + next_i)); + + // Check the conditional that contains the collective permute. 
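+  // (This conditional is what rotates the RHS shard: the windowed einsum
+  // loop dots the local LHS shard with the RHS shard it currently holds,
+  // accumulates the partial result via dynamic-update-slice, and then
+  // collective-permutes the RHS shard to a neighboring partition so the
+  // next iteration sees a different slice of the RHS.)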
+ auto cp_conditional = + while_loop->while_body()->root_instruction()->operand(1); + EXPECT_THAT(cp_conditional->true_computation()->root_instruction(), + op::CollectivePermute(op::Parameter(0))); + EXPECT_THAT(cp_conditional->false_computation()->root_instruction(), + op::Parameter(0)); +} + +TEST_F(SpmdPartitioningTest, EinsumRHSWindowedContracting) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,63,128] parameter(0) + %lhs.copy = f32[32,24,63,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[32,39296,63,128] parameter(1) + %rhs.copy = f32[32,39296,63,128] copy(%rhs), sharding={devices=[1,1,2,1]0,1} + ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={devices=[1,2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, + /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[32,12,63,128]")); + auto rhs = + AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(1), op::Constant()), + op::Constant(), op::Constant(), + op::Reshape(), op::Constant())), + op::Shape("f32[32,39296,32,128]")); + auto masked_rhs = + op::Select(op::Compare(), rhs, op::Broadcast(op::Constant())); + EXPECT_THAT(root, + AllOf(op::GetTupleElement(op::While(op::Tuple( + lhs, masked_rhs, op::Broadcast(), op::Constant()))), + op::Shape("f32[32,12,39296]"))); + auto while_loop = root->operand(0); + // Check loop condition. + EXPECT_THAT( + while_loop->while_condition()->root_instruction(), + op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant())); + + // Check loop body. + auto next_i = op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant()); + auto window = op::Conditional(op::Compare(next_i, op::Constant()), + op::GetTupleElement(op::Parameter(0)), + op::GetTupleElement(op::Parameter(0))); + auto partial_output = op::Dot( + op::DynamicSlice( + op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()), + op::Constant(), op::Constant(), op::Reshape(), op::Constant()), + op::GetTupleElement(op::Parameter(0))); + EXPECT_THAT( + while_loop->while_body()->root_instruction(), + op::Tuple(op::GetTupleElement(op::Parameter(0)), window, + op::Add(op::GetTupleElement(op::Parameter(0)), partial_output), + next_i)); + + // Check the conditional that contains the collective permute. 
+ auto cp_conditional = + while_loop->while_body()->root_instruction()->operand(1); + EXPECT_THAT(cp_conditional->true_computation()->root_instruction(), + op::CollectivePermute(op::Parameter(0))); + EXPECT_THAT(cp_conditional->false_computation()->root_instruction(), + op::Parameter(0)); +} + +TEST_F(SpmdPartitioningTest, EinsumRHSWindowedNonContractingReduce1) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + %lhs = f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[32,39295,64,128] parameter(1) + %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={devices=[1,2,1]0,1} + %constant = f32[] constant(0) + %constant.1 = f32[] constant(2) + %broadcast = f32[32,24,39295] broadcast(%constant.1), dimensions={}, + sharding={devices=[1,2,1]0,1} + %multiply = f32[32,24,39295] multiply(%dot, %broadcast), + sharding={devices=[1,2,1]0,1} + ROOT %reduce = f32[32,24] reduce(%multiply, %constant), dimensions={2}, + to_apply=sum, sharding={devices=[1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, + /*num_devices=*/2)); + VLOG(1) << module->ToString(); + // Involves loop code motion, skips pattern matching. +} + +TEST_F(SpmdPartitioningTest, EinsumRHSWindowedNonContractingReduce2) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + %lhs = f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[32,39295,64,128] parameter(1) + %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1} + %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={devices=[1,2,1]0,1} + %constant = f32[] constant(0) + %constant.1 = f32[] constant(2) + %broadcast = f32[32,24,39295] broadcast(%constant.1), dimensions={}, + sharding={devices=[1,2,1]0,1} + %multiply = f32[32,24,39295] multiply(%dot, %broadcast), + sharding={devices=[1,2,1]0,1} + ROOT %reduce = f32[32,39295] reduce(%multiply, %constant), dimensions={1}, + to_apply=sum, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, + /*num_devices=*/2)); + VLOG(1) << module->ToString(); + // Involves loop code motion, skips pattern matching. 
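+  // (The TF_ASSERT_OK_AND_ASSIGN above is the effective check: presumably
+  // the multiply/reduce consumers get moved into the windowed-dot loop, so
+  // the exact partitioned HLO is not worth pattern-matching here.)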
+} + +TEST_F(SpmdPartitioningTest, EinsumRHSWindowedContractingFromBroadcast) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %rhs = f32[32,39296,63,128] parameter(0) + %rhs.copy = f32[32,39296,63,128] copy(%rhs), sharding={devices=[1,1,2,1]0,1} + %constant.1 = f32[] constant(2) + %broadcast = f32[32,24,63,128] broadcast(%constant.1), dimensions={}, + sharding={devices=[1,2,1,1]0,1} + %add = f32[32,24,63,128] add(%broadcast, %broadcast), + sharding={devices=[1,2,1,1]0,1} + ROOT %dot = f32[32,24,39296] dot(%add, %rhs.copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, + sharding={devices=[1,2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, + /*num_devices=*/2)); + VLOG(1) << module->ToString(); + // Involves loop code motion, skips pattern matching. +} + +TEST_F(SpmdPartitioningTest, ReplicatedRng) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = s32[] parameter(0) + %lhs.copy = s32[] copy(%lhs), sharding={replicated} + %rhs = s32[] parameter(1) + %rhs.copy = s32[] copy(%rhs), sharding={replicated} + ROOT %rng = s32[4]{0} rng(%lhs.copy, %rhs.copy), + distribution=rng_uniform, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("s32[]")); + auto rhs = AllOf(op::Copy(op::Parameter(1)), op::Shape("s32[]")); + EXPECT_THAT( + root, + AllOf(op::AllReduce(op::Select( + op::Broadcast(op::Compare(op::PartitionId(), op::Constant())), + op::Rng(), op::Broadcast(op::Constant()))), + op::Shape("s32[4]"))); +} + +TEST_F(SpmdPartitioningTest, PartitionedRng) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = s32[] parameter(0) + %lhs.copy = s32[] copy(%lhs), sharding={replicated} + %rhs = s32[] parameter(1) + %rhs.copy = s32[] copy(%rhs), sharding={maximal device=1} + ROOT %rng = s32[4]{0} rng(%lhs.copy, %rhs.copy), + distribution=rng_uniform, sharding={devices=[2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::Parameter(0)), op::Shape("s32[]")); + auto rhs = AllOf(op::Copy(op::Copy(op::Parameter(1))), op::Shape("s32[]")); + EXPECT_THAT(root, AllOf(op::Rng(lhs, op::AllReduce(op::Select( + op::Broadcast(op::Compare()), rhs, + op::Broadcast(op::Constant())))), + op::Shape("s32[2]"))); +} + +TEST_F(SpmdPartitioningTest, DynamicSliceAlongNonPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %input = s32[128,64] parameter(0) + %input.copy = s32[128,64] copy(%input), sharding={devices=[2,1]0,1} + %index = s32[] parameter(1) + %constant = s32[] constant(0) + ROOT %dynamic-slice = s32[128,2] dynamic-slice(%input.copy, %constant, %index), + dynamic_slice_sizes={128,2}, sharding={devices=[2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto input = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), + op::Constant())), + op::Shape("s32[64,64]")); + EXPECT_THAT(root, + AllOf(op::DynamicSlice(input, 
op::Constant(), op::Parameter(1)), + op::Shape("s32[64,2]"))); +} + +TEST_F(SpmdPartitioningTest, DynamicUpdateSliceAlongNonPartitionedDimension) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %input = s32[128,64] parameter(0) + %input.copy = s32[128,64] copy(%input), sharding={devices=[2,1]0,1} + %index = s32[] parameter(1) + %constant = s32[] constant(0) + %update = s32[128,2] parameter(2) + %update.copy = s32[128,2] copy(%update), sharding={devices=[2,1]0,1} + ROOT %dynamic-update-slice = s32[128,64] + dynamic-update-slice(%input.copy, %update.copy, %constant, %index), + sharding={devices=[2,1]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto input = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), + op::Constant())), + op::Shape("s32[64,64]")); + auto update = AllOf(op::Copy(op::DynamicSlice(op::Parameter(2), op::Reshape(), + op::Constant())), + op::Shape("s32[64,2]")); + EXPECT_THAT(root, AllOf(op::DynamicUpdateSlice(input, update, op::Constant(), + op::Parameter(1)), + op::Shape("s32[64,64]"))); +} + +TEST_F(SpmdPartitioningTest, PassthroughGather) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[2,9] parameter(0), sharding={devices=[1,2]0,1} + %indices = s32[3] parameter(1), sharding={replicated} + ROOT %gather = f32[3,9] gather(%input, %indices), offset_dims={1}, + collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, + slice_sizes={1,9}, sharding={devices=[1,2]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Gather(op::Parameter(0), op::Parameter(1)), + op::Shape("f32[3,5]"))); +} + +TEST_F(SpmdPartitioningTest, GatherPartitionedOnTrivialSliceDims) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[17,9] parameter(0), sharding={devices=[2,1]0,1} + %indices = s32[2,3] parameter(1), sharding={replicated} + ROOT %gather = f32[2,3,9] gather(%input, %indices), offset_dims={2}, + collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, + slice_sizes={1,9}, sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto offset = op::Reshape( + op::DynamicSlice(op::Constant(), op::PartitionId(), op::Constant())); + auto min = AllOf(op::Broadcast(offset), op::Shape("s32[2,3]")); + auto max = AllOf(op::Broadcast(op::Add(offset, op::Constant())), + op::Shape("s32[2,3]")); + auto clamp = op::Clamp(min, op::Parameter(1), max); + auto gather = op::Gather(op::Parameter(0), op::Subtract(clamp, min)); + auto mask = + op::Or(op::Lt(op::Parameter(1), min), op::Gt(op::Parameter(1), max)); + auto masked = + op::Select(op::Broadcast(mask), op::Broadcast(op::Constant()), gather); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::AllReduce(masked), op::Shape("f32[2,3,9]"))); +} + +TEST_F(SpmdPartitioningTest, PassthroughScatter) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9] 
parameter(0), sharding={devices=[1,2]0,1} + %indices = s32[3] parameter(1), sharding={replicated} + %updates = f32[3,9] parameter(2), sharding={devices=[1,2]0,1} + ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, sharding={devices=[1,2]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Scatter(op::Parameter(0), op::Parameter(1), + op::Parameter(2)), + op::Shape("f32[2,5]"))); +} + +TEST_F(SpmdPartitioningTest, ScatterPartitionedOnTrivialSliceDims) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[17,9] parameter(0), sharding={devices=[2,1]0,1} + %indices = s32[2,3] parameter(1), sharding={replicated} + %updates = f32[2,3,9] parameter(2), sharding={replicated} + ROOT %scatter = f32[17,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={2}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=2, sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + auto offset = op::Reshape( + op::DynamicSlice(op::Constant(), op::PartitionId(), op::Constant())); + auto indices = op::Subtract( + op::Parameter(1), AllOf(op::Broadcast(offset), op::Shape("s32[2,3]"))); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + AllOf(op::Scatter(op::Parameter(0), indices, op::Parameter(2)), + op::Shape("f32[9,9]"))); +} + +TEST_F(SpmdPartitioningTest, TiledReverse) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + constant = f32[3,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1}}), + sharding={devices=[2,1]0,1} + ROOT reverse = f32[3,3]{1,0} reverse(constant), dimensions={1}, + sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[2,3]{1,0}"), + op::Reverse(op::DynamicSlice( + op::Pad(op::Constant(), op::Constant()), + op::Reshape(), op::Constant())))); +} + +TEST_F(SpmdPartitioningTest, MixWithManualPartitioning) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + param = f32[8,2] parameter(0), sharding={devices=[2,1]0,1} + to_shard = f32[4,2] custom-call(param), custom_call_target="SPMDFullToShardShape", sharding={replicated} + add = f32[4,2] add(to_shard, to_shard), sharding={replicated} + to_full = f32[8,2] custom-call(add), custom_call_target="SPMDShardToFullShape", sharding={devices=[2,1]0,1} + ROOT mul = f32[8,2] multiply(to_full, param), sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + auto to_shard = op::Copy(op::Parameter(0)); + EXPECT_THAT(root, AllOf(op::Shape("f32[4,2]"), + op::Multiply(op::Copy(op::Add(to_shard, to_shard)), + op::Parameter(0)))); +} + +} // 
namespace +} // namespace spmd +} // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc new file mode 100644 index 00000000000..207f854cd9f --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc @@ -0,0 +1,662 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h" + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace spmd { + +bool HasReplicatedSharding(const HloSharding& sharding) { + if (sharding.IsTuple()) { + return absl::c_any_of(sharding.tuple_elements(), HasReplicatedSharding); + } + return sharding.IsReplicated(); +} + +HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b) { + if (shape.IsTuple()) { + std::vector elements; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + elements.push_back( + CreateZero(ShapeUtil::GetTupleElementShape(shape, i), b)); + } + return b->AddInstruction(HloInstruction::CreateTuple(elements)); + } + + if (shape.IsToken()) { + return b->AddInstruction(HloInstruction::CreateToken()); + } + auto zero = b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type()))); + return b->AddInstruction(HloInstruction::CreateBroadcast(shape, zero, {})); +} + +HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module) { + HloComputation::Builder sum_b("add"); + auto x = sum_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(type, {}), "x")); + auto y = sum_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(type, {}), "y")); + if (type == PRED) { + sum_b.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(type, {}), HloOpcode::kOr, x, y)); + } else { + sum_b.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(type, {}), HloOpcode::kAdd, x, y)); + } + HloComputation* reduction = module->AddEmbeddedComputation(sum_b.Build()); + return reduction; +} + +bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding) { + if (sharding.IsTuple()) { + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + if (!EvenlyPartitions(ShapeUtil::GetTupleElementShape(shape, i), + sharding.GetSubSharding(shape, {i}))) { + return false; + } + } + } + + if (sharding.IsTileMaximal()) { + return 
sharding.IsReplicated(); + } + for (int64 i = 0; i < shape.dimensions_size(); ++i) { + if (shape.dimensions(i) % sharding.tile_assignment().dim(i) != 0) { + return false; + } + } + return true; +} + +Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding) { + if (sharding.IsTuple()) { + std::vector subshapes; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + subshapes.push_back( + MakePartitionedShape(ShapeUtil::GetTupleElementShape(shape, i), + sharding.GetSubSharding(shape, {i}))); + } + return ShapeUtil::MakeTupleShape(subshapes); + } + return sharding.TileShape(shape); +} + +Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape, + const HloSharding& sharding, + int64 partition_id) { + if (sharding.IsTuple()) { + std::vector subshapes; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + subshapes.push_back(MakeNonPaddedShapeForGivenPartition( + ShapeUtil::GetTupleElementShape(shape, i), + sharding.GetSubSharding(shape, {i}), partition_id)); + } + return ShapeUtil::MakeTupleShape(subshapes); + } + + auto partition_shape = shape; + std::vector tile_offset = + sharding.TileOffsetForDevice(shape, partition_id); + std::vector tile_limit = + sharding.TileLimitForDevice(shape, partition_id); + for (int64 i = 0; i < tile_offset.size(); ++i) { + if (sharding.UsesDevice(partition_id)) { + partition_shape.set_dimensions(i, tile_limit[i] - tile_offset[i]); + } else { + partition_shape.set_dimensions(i, 0); + } + } + return partition_shape; +} + +std::vector MakePartitionOffsets(const Shape& shape, + const HloSharding& sharding, + HloInstruction* partition_id, + SpmdBuilder* b) { + CHECK(!shape.IsTuple()); + + Array2D offset_array( + {sharding.tile_assignment().num_elements(), shape.rank()}); + offset_array.Each([&](int64 i, int64 j, int32* value) { + *value = sharding.TileOffsetForDevice(shape, i)[j]; + }); + auto offset_table = b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2FromArray2D(offset_array))); + std::vector offsets; + for (int64 i = 0; i < shape.rank(); ++i) { + if (sharding.tile_assignment().dim(i) == 1) { + offsets.push_back(b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(S32)))); + } else { + auto index = b->AddInstruction(HloInstruction::CreateDynamicSlice( + ShapeUtil::MakeShape(S32, {1, 1}), offset_table, + {partition_id, b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(i)))}, + {1, 1})); + offsets.push_back(b->AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), index))); + } + } + return offsets; +} + +std::vector MakeTiledPartitionOrdinals( + const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b) { + CHECK(!sharding.IsTileMaximal()); + auto table_shape = + ShapeUtil::MakeShape(S32, sharding.tile_assignment().dimensions()); + return MakePartitionOffsets(table_shape, sharding, partition_id, b); +} + +HloInstruction* PadToShape(HloInstruction* hlo, const Shape& padded_shape, + SpmdBuilder* b, HloComputation* computation) { + CHECK(b == nullptr || computation == nullptr); + if (ShapeUtil::Compatible(hlo->shape(), padded_shape)) { + return hlo; + } + PaddingConfig padding_config; + for (int64 i = 0; i < padded_shape.rank(); ++i) { + auto padding_config_dim = padding_config.add_dimensions(); + padding_config_dim->set_edge_padding_low(0); + padding_config_dim->set_interior_padding(0); + padding_config_dim->set_edge_padding_high(padded_shape.dimensions(i) - + hlo->shape().dimensions(i)); + } + auto add_hlo 
= [&](std::unique_ptr to_add) { + if (b == nullptr) { + return computation->AddInstruction(std::move(to_add)); + } + return b->AddInstruction(std::move(to_add)); + }; + auto zero = add_hlo(HloInstruction::CreateConstant( + LiteralUtil::Zero(hlo->shape().element_type()))); + return add_hlo( + HloInstruction::CreatePad(padded_shape, hlo, zero, padding_config)); +} + +Shape GetPaddedShapeForUnevenPartitioning(const Shape& base_shape, + const HloSharding& sharding) { + if (sharding.IsTileMaximal()) { + return base_shape; + } + if (EvenlyPartitions(base_shape, sharding)) { + return base_shape; + } + auto shard_shape = MakePartitionedShape(base_shape, sharding); + Shape padded_base_shape = base_shape; + for (int64 i = 0; i < padded_base_shape.rank(); ++i) { + padded_base_shape.set_dimensions( + i, shard_shape.dimensions(i) * sharding.tile_assignment().dim(i)); + } + return padded_base_shape; +} + +HloInstruction* PadBaseShapeBeforeUnevenTiledSharding( + HloInstruction* hlo, const HloSharding& sharding, SpmdBuilder* b) { + auto padded_base_shape = + GetPaddedShapeForUnevenPartitioning(hlo->shape(), sharding); + if (ShapeUtil::Compatible(padded_base_shape, hlo->shape())) { + return hlo; + } + return PadToShape(hlo, padded_base_shape, b); +} + +absl::optional UniqueTiledDim(const HloSharding& sharding) { + if (sharding.IsTileMaximal()) { + return absl::nullopt; + } + int64 dim = -1; + for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) { + if (sharding.tile_assignment().dim(i) > 1) { + if (dim != -1) { + return absl::nullopt; + } + dim = i; + } + } + CHECK_NE(dim, -1); + return dim; +} + +MultiplyAddDivideOffsetCalculation::MultiplyAddDivideOffsetCalculation( + int64 multiplier, int64 offset, int64 divisor) + : multiplier_(multiplier), offset_(offset), divisor_(divisor) { + CHECK_GT(divisor_, 0); + Simplify(); +} + +OffsetCalculation MultiplyAddDivideOffsetCalculation::operator-( + const MultiplyAddDivideOffsetCalculation& other) const { + if (divisor_ == 1 && other.divisor_ == 1) { + return OffsetCalculation(MultiplyAddDivideOffsetCalculation( + multiplier_ - other.multiplier_, offset_ - other.offset_, 1)); + } + return OffsetCalculation(HloOpcode::kSubtract, *this, other); +} + +void MultiplyAddDivideOffsetCalculation::Simplify() { + // We could simplify the calculation when multiplier is a multiple of + // divisor_. However, when offset_ is not a multiple of divisor_, we must + // make sure that offset_ and multiplier_ are both non-negative or both + // non-positive. E.g., (3 * i - 1) / 3 is not equivalent to i or i - 1. 
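+  // Concretely (illustrative values): with multiplier_ = 3, offset_ = -1,
+  // divisor_ = 3, shard ordinal 0 maps to (3 * 0 - 1) / 3 = 0 under C++
+  // truncating division (so the result is not i - 1), while ordinal 1 maps
+  // to (3 * 1 - 1) / 3 = 0 (so it is not i either); hence the guard below
+  // also requires offset_ % divisor_ == 0 || offset_ * multiplier_ > 0.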
+ if (divisor_ != 1 && multiplier_ % divisor_ == 0 && + (offset_ % divisor_ == 0 || offset_ * multiplier_ > 0)) { + multiplier_ /= divisor_; + offset_ /= divisor_; + divisor_ = 1; + } +} + +int64 MultiplyAddDivideOffsetCalculation::Calculate(int64 shard_ordinal) const { + return (shard_ordinal * multiplier_ + offset_) / divisor_; +} + +HloInstruction* MultiplyAddDivideOffsetCalculation::Calculate( + HloInstruction* shard_ordinal, SpmdBuilder* b) const { + auto scalar_shape = ShapeUtil::MakeShape(S32, {}); + if (multiplier_ == 0) { + return b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(offset_ / divisor_))); + } + HloInstruction* result = shard_ordinal; + if (multiplier_ != 1) { + result = b->AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kMultiply, shard_ordinal, + b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(multiplier_))))); + } + if (offset_ != 0) { + auto offset = b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(offset_))); + result = b->AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, result, offset)); + } + if (divisor_ != 1) { + auto divisor = b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(divisor_))); + result = b->AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kDivide, result, divisor)); + } + return result; +} + +int64 MultiplyAddDivideOffsetCalculation::MaxInRange( + int64 start_ordinal, int64 limit_ordinal) const { + int64 max = Calculate(start_ordinal); + for (int64 i = start_ordinal + 1; i < limit_ordinal; ++i) { + max = std::max(max, Calculate(i)); + } + return max; +} + +OffsetCalculation& OffsetCalculation::operator=( + const OffsetCalculation& other) { + opcode_ = other.opcode_; + copy_from_ = other.copy_from_; + if (opcode_ != HloOpcode::kCopy) { + lhs_ = absl::make_unique(*other.lhs_); + rhs_ = absl::make_unique(*other.rhs_); + } + return *this; +} + +bool OffsetCalculation::IsConstant() const { + if (opcode_ == HloOpcode::kCopy) { + return copy_from_.IsConstant(); + } + if (opcode_ == HloOpcode::kSubtract && *lhs_ == *rhs_) { + return true; + } + return lhs_->IsConstant() && rhs_->IsConstant(); +} + +OffsetCalculation OffsetCalculation::operator-( + const OffsetCalculation& other) const { + if (opcode_ == HloOpcode::kCopy && other.opcode_ == HloOpcode::kCopy) { + return copy_from_ - other.copy_from_; + } + return OffsetCalculation(HloOpcode::kSubtract, *this, other); +} + +bool OffsetCalculation::operator==(const OffsetCalculation& other) const { + if (opcode_ != other.opcode_) { + return false; + } + if (opcode_ == HloOpcode::kCopy) { + return copy_from_ == other.copy_from_; + } + return *lhs_ == *other.lhs_ && *rhs_ == *other.rhs_; +} + +int64 OffsetCalculation::Calculate(int64 shard_ordinal) const { + switch (opcode_) { + case HloOpcode::kCopy: + return copy_from_.Calculate(shard_ordinal); + case HloOpcode::kSubtract: + return lhs_->Calculate(shard_ordinal) - rhs_->Calculate(shard_ordinal); + case HloOpcode::kMultiply: + return lhs_->Calculate(shard_ordinal) * rhs_->Calculate(shard_ordinal); + default: + LOG(FATAL) << "Should not happen"; + } +} + +HloInstruction* OffsetCalculation::Calculate(HloInstruction* shard_ordinal, + SpmdBuilder* b) const { + if (opcode_ == HloOpcode::kCopy) { + return copy_from_.Calculate(shard_ordinal, b); + } + auto lhs = lhs_->Calculate(shard_ordinal, b); + auto rhs = rhs_->Calculate(shard_ordinal, b); + return b->AddInstruction( + 
HloInstruction::CreateBinary(lhs->shape(), opcode_, lhs, rhs)); +} + +int64 OffsetCalculation::MaxInRange(int64 start_ordinal, + int64 limit_ordinal) const { + if (IsConstant()) { + return Calculate(start_ordinal); + } + if (opcode_ == HloOpcode::kCopy) { + return std::max(Calculate(start_ordinal), Calculate(limit_ordinal - 1)); + } + int64 max = Calculate(start_ordinal); + for (int64 i = start_ordinal + 1; i < limit_ordinal; ++i) { + max = std::max(max, Calculate(i)); + } + return max; +} + +absl::optional ExchangeHalo( + HloInstruction* hlo, const OffsetCalculation& left_halo_size_function, + const OffsetCalculation& right_halo_size_function, int64 dim, + const HloSharding& target, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdBuilder* b) { + int64 input_shard_size = hlo->shape().dimensions(dim); + int64 shard_count = target.tile_assignment().dim(dim); + + std::vector concat_pieces; + + int64 max_left_halo_size = left_halo_size_function.MaxInRange(1, shard_count); + if (max_left_halo_size > input_shard_size) { + VLOG(1) << "ExchangeHalo failed: halo is beyond the left neighbor."; + return absl::nullopt; + } + if (max_left_halo_size > 0) { + std::vector> source_target_pairs; + target.tile_assignment().Each( + [&](absl::Span indices, int64 device) { + if (indices[dim] > 0) { + std::vector source_indices(indices.begin(), indices.end()); + source_indices[dim] -= 1; + source_target_pairs.emplace_back( + target.tile_assignment()(source_indices), device); + } + }); + auto halo_shape = hlo->shape(); + auto source_halo_slice = hlo; + if (max_left_halo_size != hlo->shape().dimensions(dim)) { + halo_shape.set_dimensions(dim, max_left_halo_size); + std::vector halo_start_indices(halo_shape.rank(), 0); + halo_start_indices[dim] = + hlo->shape().dimensions(dim) - max_left_halo_size; + std::vector halo_slice_strides(halo_shape.rank(), 1); + + source_halo_slice = b->AddInstruction( + hlo->CreateSlice(halo_shape, hlo, halo_start_indices, + hlo->shape().dimensions(), halo_slice_strides)); + } + auto left_halo = + collective_ops_creator.create_cross_partition_collective_permute( + b, source_halo_slice, source_target_pairs, (*next_channel_id)++); + concat_pieces.push_back(left_halo); + } + + concat_pieces.push_back(hlo); + + // Right halo. + int64 max_right_halo_size = + right_halo_size_function.MaxInRange(0, shard_count - 1); + if (max_right_halo_size > input_shard_size) { + VLOG(1) << "ExchangeHalo failed: halo is beyond the right neighbor."; + return absl::nullopt; + } + if (max_right_halo_size > 0) { + std::vector> source_target_pairs; + target.tile_assignment().Each( + [&](absl::Span indices, int64 device) { + if (indices[dim] > 0) { + std::vector target_indices(indices.begin(), indices.end()); + target_indices[dim] -= 1; + source_target_pairs.emplace_back( + device, target.tile_assignment()(target_indices)); + } + }); + auto halo_shape = hlo->shape(); + halo_shape.set_dimensions(dim, max_right_halo_size); + std::vector halo_start_indices(halo_shape.rank(), 0); + std::vector halo_slice_strides(halo_shape.rank(), 1); + + auto source_halo_slice = b->AddInstruction( + hlo->CreateSlice(halo_shape, hlo, halo_start_indices, + halo_shape.dimensions(), halo_slice_strides)); + auto right_halo = + collective_ops_creator.create_cross_partition_collective_permute( + b, source_halo_slice, source_target_pairs, (*next_channel_id)++); + concat_pieces.push_back(right_halo); + } + + auto concat = hlo; + // Concat with halos/padding. 
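+  // The pieces are, in order: the left halo received from the previous shard
+  // (if any), this shard's own data, and the right halo received from the
+  // next shard (if any). For example, a shard size of 4 with maximum halo
+  // sizes of 1 on each side yields a concat of size 6 along `dim`.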
+ if (concat_pieces.size() > 1) { + auto concat_shape = hlo->shape(); + int64 concat_dim_size = 0; + for (auto piece : concat_pieces) { + concat_dim_size += piece->shape().dimensions(dim); + } + concat_shape.set_dimensions(dim, concat_dim_size); + concat = b->AddInstruction( + HloInstruction::CreateConcatenate(concat_shape, concat_pieces, dim)); + } + + return concat; +} + +absl::optional ExchangeHalo( + HloInstruction* hlo, + std::vector left_halo_size_functions, + std::vector right_halo_size_functions, + const HloSharding& target, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdBuilder* b) { + CHECK(left_halo_size_functions.size() == hlo->shape().rank()); + CHECK(right_halo_size_functions.size() == hlo->shape().rank()); + + HloInstruction* visiting_hlo = hlo; + for (int dim = 0; dim < hlo->shape().rank(); ++dim) { + auto concat = ExchangeHalo(visiting_hlo, left_halo_size_functions[dim], + right_halo_size_functions[dim], dim, target, + collective_ops_creator, next_channel_id, b); + if (!concat) { + return absl::nullopt; + } + visiting_hlo = *concat; + } + return visiting_hlo; +} + +absl::optional ExchangeHaloAndGetValidData( + HloInstruction* hlo, const Shape& base_shape, + const OffsetCalculation& left_halo_size_function, + const OffsetCalculation& right_halo_size_function, + int64 explicit_left_padding_on_full_shape, int64 padded_full_shape_size, + int64 shard_size_with_halo, int64 dim, const HloSharding& target, + HloInstruction* offset_on_padded_shape, HloInstruction* pad_value, + HloInstruction* partition_ordinal, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdBuilder* b, bool mask_invalid_region) { + auto halo_exchange_result = + ExchangeHalo(hlo, left_halo_size_function, right_halo_size_function, dim, + target, collective_ops_creator, next_channel_id, b); + if (!halo_exchange_result) { + return absl::nullopt; + } + auto concat = *halo_exchange_result; + int64 shard_count = target.tile_assignment().dim(dim); + int64 max_left_halo_size = left_halo_size_function.MaxInRange(1, shard_count); + + // Now we determine if we need extra padding after the concat. + // + // The max of halo size or the first shard's explicit left padding. + int64 max_left_halo_or_padding_size = + std::max(std::max(int64{0}, max_left_halo_size), + explicit_left_padding_on_full_shape); + // The calculation that returns the dynamic slice index for a shard on the + // padded concat, which is the difference between + // max_left_halo_or_padding_size and its left halo size. + auto start_offset_on_padded_concat_calculation = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + 0, max_left_halo_or_padding_size, 1)) - + left_halo_size_function; + + // See if we need to pad the concat before dynamic slice. 
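+  // extra_left_padding covers whatever part of
+  // explicit_left_padding_on_full_shape the largest left halo did not already
+  // provide, and extra_right_padding grows the concat so that the largest
+  // dynamic-slice start offset plus shard_size_with_halo still stays in
+  // bounds.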
+ int64 extra_left_padding = + std::max(int64{0}, max_left_halo_or_padding_size - + std::max(int64{0}, max_left_halo_size)); + int64 extra_right_padding = + start_offset_on_padded_concat_calculation.MaxInRange(0, shard_count) + + shard_size_with_halo - concat->shape().dimensions(dim) - + extra_left_padding; + extra_right_padding = std::max(int64{0}, extra_right_padding); + if (extra_left_padding > 0 || extra_right_padding > 0) { + PaddingConfig padding_config; + auto padded_concat_shape = concat->shape(); + for (int64 i = 0; i < base_shape.rank(); ++i) { + auto padding_config_dim = padding_config.add_dimensions(); + padding_config_dim->set_interior_padding(0); + padding_config_dim->set_edge_padding_low(0); + padding_config_dim->set_edge_padding_high(0); + if (i != dim) { + continue; + } + padding_config_dim->set_edge_padding_low(extra_left_padding); + padding_config_dim->set_edge_padding_high(extra_right_padding); + padded_concat_shape.set_dimensions(dim, concat->shape().dimensions(dim) + + extra_left_padding + + extra_right_padding); + } + concat = b->AddInstruction(HloInstruction::CreatePad( + padded_concat_shape, concat, pad_value, padding_config)); + } + + auto valid_slice = concat; + if (shard_size_with_halo != concat->shape().dimensions(dim)) { + // Concat is bigger than the shard shape, so we need a dynamic slice. + CHECK_LT(shard_size_with_halo, concat->shape().dimensions(dim)); + auto slice_shape = concat->shape(); + slice_shape.set_dimensions(dim, shard_size_with_halo); + + if (left_halo_size_function.IsConstant() && + left_halo_size_function.Calculate(0) == + explicit_left_padding_on_full_shape) { + std::vector start_indices(slice_shape.rank(), 0); + std::vector strides(slice_shape.rank(), 1); + valid_slice = b->AddInstruction( + HloInstruction::CreateSlice(slice_shape, concat, start_indices, + slice_shape.dimensions(), strides)); + } else { + auto zero = b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(S32))); + std::vector slice_offsets(base_shape.rank(), zero); + slice_offsets[dim] = start_offset_on_padded_concat_calculation.Calculate( + partition_ordinal, b); + valid_slice = b->AddInstruction(HloInstruction::CreateDynamicSlice( + slice_shape, concat, slice_offsets, slice_shape.dimensions())); + } + } + + if (!mask_invalid_region) { + return valid_slice; + } + + int64 total_right_padding = padded_full_shape_size - + base_shape.dimensions(dim) - + explicit_left_padding_on_full_shape; + // Mask off garbage data due to uneven partition or low/high padding. 
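+  // An element is kept iff its index on the padded full shape (iota along
+  // `dim` plus offset_on_padded_shape) falls in
+  // [explicit_left_padding_on_full_shape,
+  //  explicit_left_padding_on_full_shape + base_shape.dimensions(dim));
+  // everything else is replaced with pad_value. E.g., with a left padding of
+  // 1 and a base dimension of 5, only indices 1 through 5 keep their data.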
+ if (explicit_left_padding_on_full_shape > 0 || total_right_padding > 0) { + auto index_shape = ShapeUtil::ChangeElementType(valid_slice->shape(), S32); + auto iota = b->AddInstruction(HloInstruction::CreateIota(index_shape, dim)); + auto broadcast_start_index_in_padded_shape = + b->AddInstruction(HloInstruction::CreateBroadcast( + index_shape, offset_on_padded_shape, {})); + auto index_in_padded_shape = b->AddInstruction( + HloInstruction::CreateBinary(index_shape, HloOpcode::kAdd, iota, + broadcast_start_index_in_padded_shape)); + auto mask_shape = ShapeUtil::ChangeElementType(index_shape, PRED); + std::vector predicates; + if (explicit_left_padding_on_full_shape > 0) { + auto valid_index_start = + b->AddInstruction(HloInstruction::CreateBroadcast( + index_shape, + b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0( + explicit_left_padding_on_full_shape))), + {})); + predicates.push_back(b->AddInstruction(HloInstruction::CreateCompare( + mask_shape, index_in_padded_shape, valid_index_start, + ComparisonDirection::kGe))); + } + if (total_right_padding > 0) { + auto valid_index_limit = + b->AddInstruction(HloInstruction::CreateBroadcast( + index_shape, + b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0( + base_shape.dimensions(dim) + + explicit_left_padding_on_full_shape))), + {})); + predicates.push_back(b->AddInstruction(HloInstruction::CreateCompare( + mask_shape, index_in_padded_shape, valid_index_limit, + ComparisonDirection::kLt))); + } + CHECK(!predicates.empty()); + auto is_valid = + predicates.size() == 2 + ? b->AddInstruction(HloInstruction::CreateBinary( + mask_shape, HloOpcode::kAnd, predicates[0], predicates[1])) + : predicates[0]; + auto masking_value = b->AddInstruction( + HloInstruction::CreateBroadcast(valid_slice->shape(), pad_value, {})); + valid_slice = b->AddInstruction( + HloInstruction::CreateTernary(valid_slice->shape(), HloOpcode::kSelect, + is_valid, valid_slice, masking_value)); + } + return valid_slice; +} + +} // namespace spmd +} // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h new file mode 100644 index 00000000000..f96b23d7073 --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h @@ -0,0 +1,229 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h" + +namespace xla { +namespace spmd { + +// Returns true if the given sharding contains any replicated sharding. +bool HasReplicatedSharding(const HloSharding& sharding); + +// Creates zero value instructions of the given shape. +HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b); + +template +HloInstruction* CreateR0WithType(PrimitiveType type, NativeT value, + SpmdBuilder* b) { + auto literal = LiteralUtil::CreateR0(value) + .ConvertToShape(ShapeUtil::MakeShape(type, {})) + .ValueOrDie(); + return b->AddInstruction(HloInstruction::CreateConstant(std::move(literal))); +} + +// Create a binary add computation of the given type and add to the module. +HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module); + +// Returns true if the shape can be evenly partitioned for the given sharding. +// All tile sharded dimensions should be evenly divisible and there should be no +// single-device sharding. Replicate sharding is considered even partition. +bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding); + +// Returns the shard shape of the given shape when it is partitioned for the +// target sharding. +Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding); + +// Returns the shard shape for a partition without padding due to uneven +// sharding. +Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape, + const HloSharding& sharding, + int64 partition_id); + +// Generates the HLO instructions that represent the dimension offsets on any +// device. The size of the returned vector is the rank of the given shape. +std::vector MakePartitionOffsets(const Shape& shape, + const HloSharding& sharding, + HloInstruction* partition_id, + SpmdBuilder* b); + +// Returns the offsets of the partition in the tile assignment. +std::vector MakeTiledPartitionOrdinals( + const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b); + +// Pads hlo to the desired shape using high padding. Either a builder or a +// computation needs to be supplied, but not both. +HloInstruction* PadToShape(HloInstruction* hlo, const Shape& padded_shape, + SpmdBuilder* b, + HloComputation* computation = nullptr); + +// Returns the padded shape when combining all partitions. +Shape GetPaddedShapeForUnevenPartitioning(const Shape& base_shape, + const HloSharding& sharding); + +// Pads the HLO (with base shape) for uneven tiled partition to make it evenly +// partitionable. +HloInstruction* PadBaseShapeBeforeUnevenTiledSharding( + HloInstruction* hlo, const HloSharding& sharding, SpmdBuilder* b); + +// Returns the index of the unique tile dimension. Returns absl::nullopt if the +// given sharding is not tiled or tiled along multiple dimensions. +absl::optional UniqueTiledDim(const HloSharding& sharding); + +// Utilities for symbolic offset calculation and halo exchange. 
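+//
+// Illustrative example of the classes declared below: a left-halo size of
+// (shard_ordinal + 1) elements is expressed as
+// MultiplyAddDivideOffsetCalculation(/*multiplier=*/1, /*offset=*/1,
+// /*divisor=*/1), for which Calculate(3) returns 4 and MaxInRange(0, 4)
+// returns 4. Subtracting two such objects produces an OffsetCalculation that
+// can likewise be evaluated on a constant shard ordinal or emitted as scalar
+// S32 HLO via Calculate(shard_ordinal, b).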
+class OffsetCalculation; + +// Represents a calculation over integers: +// (shard_ordinal * multiplier + offset) / divisor +class MultiplyAddDivideOffsetCalculation { + public: + MultiplyAddDivideOffsetCalculation() + : multiplier_(0), offset_(0), divisor_(1) {} + MultiplyAddDivideOffsetCalculation(int64 multiplier, int64 offset, + int64 divisor); + + OffsetCalculation operator-( + const MultiplyAddDivideOffsetCalculation& other) const; + + bool operator==(const MultiplyAddDivideOffsetCalculation& other) const { + return multiplier_ == other.multiplier_ && offset_ == other.offset_ && + divisor_ == other.divisor_; + } + + bool IsConstant() const { return multiplier_ == 0; } + void Simplify(); + int64 Calculate(int64 shard_ordinal) const; + HloInstruction* Calculate(HloInstruction* shard_ordinal, + SpmdBuilder* b) const; + + // Returns the maximum result for shard ordinals in the range + // [start_ordinal, limit_ordinal). + int64 MaxInRange(int64 start_ordinal, int64 limit_ordinal) const; + + private: + int64 multiplier_; + int64 offset_; + int64 divisor_; +}; + +// Represents a calculation over integers based on results of other calculations +// defined by an opcode. If the opcode is kCopy, it simply wraps an +// MultiplyAddDivideOffsetCalculation. +class OffsetCalculation { + public: + OffsetCalculation() : opcode_(HloOpcode::kCopy), copy_from_() {} + explicit OffsetCalculation( + const MultiplyAddDivideOffsetCalculation& copy_from) + : opcode_(HloOpcode::kCopy), copy_from_(copy_from) {} + OffsetCalculation(const OffsetCalculation& copy_from) { *this = copy_from; } + OffsetCalculation(HloOpcode opcode, + const MultiplyAddDivideOffsetCalculation& lhs, + const MultiplyAddDivideOffsetCalculation& rhs) + : opcode_(opcode), + lhs_(absl::make_unique(lhs)), + rhs_(absl::make_unique(rhs)) {} + OffsetCalculation(HloOpcode opcode, const OffsetCalculation& lhs, + const OffsetCalculation& rhs) + : opcode_(opcode), + lhs_(absl::make_unique(lhs)), + rhs_(absl::make_unique(rhs)) {} + + OffsetCalculation& operator=(const OffsetCalculation& other); + + // Returns whether the calculation returns the same value for all shards. This + // is conservative and could return false even if it is actually constant. + bool IsConstant() const; + + OffsetCalculation operator-(const OffsetCalculation& other) const; + bool operator==(const OffsetCalculation& other) const; + int64 Calculate(int64 shard_ordinal) const; + HloInstruction* Calculate(HloInstruction* shard_ordinal, + SpmdBuilder* b) const; + + // Returns the maximum result for shard ordinals in the range + // [start_ordinal, limit_ordinal). + int64 MaxInRange(int64 start_ordinal, int64 limit_ordinal) const; + + private: + HloOpcode opcode_; + std::unique_ptr lhs_; + std::unique_ptr rhs_; + MultiplyAddDivideOffsetCalculation copy_from_; +}; + +// Performs halo exchange on the given dimension based on the provided +// left/right halo size functions. Returns nullopt if the halo is beyond the +// direct neighbor of the shard. +absl::optional ExchangeHalo( + HloInstruction* hlo, const OffsetCalculation& left_halo_size_function, + const OffsetCalculation& right_halo_size_function, int64 dim, + const HloSharding& target, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdBuilder* b); + +// Exchange halo on all dimensions of the HLO. Returns nullopt if any one of the +// dimensions fails to exchange halo (halo is beyond the neighbor shard). 
+absl::optional ExchangeHalo( + HloInstruction* hlo, + std::vector left_halo_size_functions, + std::vector right_halo_size_functions, + const HloSharding& target, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdBuilder* b); + +// Exchanges halos and performs pad/dynamic-slice on the concatenated data such +// that the result starts with the first needed element on each shard. It also +// masks off invalid data due to padding. +// Arguments: +// hlo: the HLO op before halo exchange +// explicit_left_padding_on_full_shape: the amount of left padding to be added +// explicitly by this function on the base shape before partitioning. Without +// base dilation, this is usually set to the window's padding_low so that the +// sharded op do not need to add padding_low on the window; however, with base +// dilation, this could only be set to a custom size. +// padded_full_shape_size: the size of the padded full shape on the given +// dimension, which includes explicit_left_padding_on_full_shape and required +// right padding to make the shape evenly shardable. +// shard_size_with_halo: the shard size on the dimension after halo exchange. +// If different shards have different sizes, use the maximum size. +// offset_on_padded_shape: the offset HLO (S32) that represents the start of +// each shard on the padded full shape. +// pad_value: the padding value used on the full shape. +absl::optional ExchangeHaloAndGetValidData( + HloInstruction* hlo, const Shape& base_shape, + const OffsetCalculation& left_halo_size_function, + const OffsetCalculation& right_halo_size_function, + int64 explicit_left_padding_on_full_shape, int64 padded_full_shape_size, + int64 shard_size_with_halo, int64 dim, const HloSharding& target, + HloInstruction* offset_on_padded_shape, HloInstruction* pad_value, + HloInstruction* partition_ordinal, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, SpmdBuilder* b, bool mask_invalid_region = true); + +} // namespace spmd +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index 2d33184b7d0..1111811d3a3 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -300,7 +300,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( } StatusOr WhileLoopInvariantCodeMotion::Run(HloModule* module) { - VLOG(2) << "HLO module before WhileLoopConstantSinking:"; + VLOG(2) << "HLO module before WhileLoopInvariantCodeMotion:"; XLA_VLOG_LINES(2, module->ToString()); bool changed = false; @@ -332,10 +332,10 @@ StatusOr WhileLoopInvariantCodeMotion::Run(HloModule* module) { } if (changed) { - VLOG(2) << "HLO module after WhileLoopConstantSinking:"; + VLOG(2) << "HLO module after WhileLoopInvariantCodeMotion:"; XLA_VLOG_LINES(2, module->ToString()); } else { - VLOG(2) << "HLO module unchanged after WhileLoopConstantSinking"; + VLOG(2) << "HLO module unchanged after WhileLoopInvariantCodeMotion"; } return changed; diff --git a/tensorflow/compiler/xla/shape.h b/tensorflow/compiler/xla/shape.h index 2793ddfc1ae..dfaac677724 100644 --- a/tensorflow/compiler/xla/shape.h +++ b/tensorflow/compiler/xla/shape.h @@ -63,6 +63,8 @@ class Shape { // shapes are traversed recursively. 
bool is_static() const; + bool is_dynamic() const { return !is_static(); } + // Returns true if the given dimension is dynamically-sized. bool is_dynamic_dimension(int dimension) const { return dynamic_dimensions_.at(dimension); diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 22ee5a16a30..52cbb8f95ac 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" #include "absl/strings/ascii.h" #include "absl/strings/numbers.h" @@ -150,6 +151,19 @@ StatusOr MakeShapeWithLayoutInternal( return equal; } +/* static */ bool ShapeUtil::EqualStructure(const Shape& lhs, + const Shape& rhs) { + bool equal = true; + ForEachSubshape(lhs, [&](const Shape& /*subshape*/, const ShapeIndex& index) { + equal &= IndexIsValid(rhs, index); + }); + ForEachSubshape(rhs, [&](const Shape& /*subshape*/, const ShapeIndex& index) { + equal &= IndexIsValid(lhs, index); + }); + + return equal; +} + /* static */ int64 ShapeUtil::TrueRank(const Shape& shape) { int64 accum = 0; for (int64 dimension : shape.dimensions()) { @@ -261,6 +275,12 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return ValidateShape(*shape); } +/* static */ Shape ShapeUtil::MakeStaticShape(const Shape& original) { + Shape result = original; + result.clear_dynamic_dimensions(); + return result; +} + /* static */ Shape ShapeUtil::MakeTupleShape(absl::Span shapes) { Shape result; result.set_element_type(TUPLE); @@ -626,8 +646,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( if (shape.element_type() == TUPLE) { return ByteSizeOfTupleIndexTable(shape, pointer_size); } else if (shape.IsArray()) { - int64 byte_size = ByteSizeOfElements(shape); - return byte_size; + return ByteSizeOfElements(shape); } else if (shape.element_type() == TOKEN) { return 0; } else if (shape.element_type() == OPAQUE_TYPE) { @@ -1441,6 +1460,19 @@ ShapeUtil::ReshapeLeavesDimensionsUnmodified( return shape; } +/* static */ bool ShapeUtil::DynamicShapeIsCompatible( + const xla::Shape& dynamic_shape, const xla::Shape& bounded_shape) { + if (dynamic_shape.rank() != bounded_shape.rank()) { + return false; + } + for (int64 i = 0; i < dynamic_shape.rank(); ++i) { + if (dynamic_shape.dimensions(i) > bounded_shape.dimensions(i)) { + return false; + } + } + return true; +} + /* static */ Shape ShapeUtil::FilterDimensions( const std::function& p, Shape shape) { CHECK(shape.IsArray()); diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 7e05e17865d..dde56587482 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -298,6 +298,16 @@ class ShapeUtil { // As Equal, but allow one of lhs and rhs to be F16 while the other is F32. static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs); + // Two shapes have same structure if all subshape indices of lhs are presented + // on rhs and vice versa. 
+ // A nested tuple shape of (F32, (S32[2], F32[2, 2])) is structurally equal to + // (S32, (F32[3], S32[2])) as their structures are both (,(,)) + // + // In contrast, (F32, (F32, F32)) is structurally different from + // ((F32, F32), F32) as the former has structure (,(,)) while the latter has + // ((,),) + static bool EqualStructure(const Shape& lhs, const Shape& rhs); + // Returns the number of dimensions for which the dimension is not (trivially) // 1. e.g., f32[2x1x1] has a true rank of 1D, the other dimensions are just // fluff. Note that zero dimensions are included in the true rank, e.g., @@ -339,6 +349,9 @@ class ShapeUtil { // element type changed to type. static Shape ChangeElementType(const Shape& original, PrimitiveType type); + // Retursn a shape with same dimensions but with all dimensions set to static. + static Shape MakeStaticShape(const Shape& original); + // Creates a tuple shape from a slice of element shapes within the tuple. static Shape MakeTupleShape(absl::Span shapes); @@ -643,12 +656,16 @@ class ShapeUtil { static Shape FilterDimensions(const std::function& p, Shape shape); - // Iterates through all the shape indexes, in minor to major order, starting - // from the base indexes, incrementing by the incr steps, up to count - // (index[i] < base[i] + count[i]), and calls the visitor_function with the - // current index. - // The visitor_function visitor function should return true if it wants to - // continue, or false otherwise. + // Returns true if `dynamic_shape` has dimensions that are less-equal to the + // "bounded_shape". + static bool DynamicShapeIsCompatible(const xla::Shape& dynamic_shape, + const xla::Shape& bounded_shape); + + // Iterates through all the shape indexes, in minor to major order, + // starting from the base indexes, incrementing by the incr steps, up to + // count (index[i] < base[i] + count[i]), and calls the visitor_function + // with the current index. The visitor_function visitor function should + // return true if it wants to continue, or false otherwise. // // visitor_function must be a callable of type // StatusOr(absl::Span) or compatible. diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index c453c5fefa0..c8a242c156a 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -1104,6 +1104,7 @@ xla_test( shard_count = 40, tags = [ "no_rocm", + "nozapfhahn", "optonly", ], deps = CONVOLUTION_TEST_DEPS + [ @@ -1500,6 +1501,7 @@ xla_test( srcs = ["select_and_scatter_test.cc"], tags = [ "no_rocm", + "nozapfhahn", "optonly", ], deps = [ diff --git a/tensorflow/compiler/xla/tests/exhaustive_unary_test_f32_or_smaller.cc b/tensorflow/compiler/xla/tests/exhaustive_unary_test_f32_or_smaller.cc index 0ed79fa0ad8..44e1b7b5a6f 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_unary_test_f32_or_smaller.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_unary_test_f32_or_smaller.cc @@ -352,6 +352,17 @@ UNARY_TEST_FLOAT_32_BITS_OR_LESS(Sqrt, { Run(Sqrt, std::sqrt, error_spec_gen); }) +UNARY_TEST_FLOAT_32_BITS_OR_LESS(Cbrt, { + if (platform_ == "Host" || platform_ == "CUDA") { + ErrorSpecGen error_spec_gen = +[](NativeT x) { + return ErrorSpec{0.01, 0.01}; + }; + Run(Cbrt, std::cbrt, error_spec_gen); + } else { + Run(Cbrt, std::cbrt); + } +}) + // TODO(jlebar): Test trig functions over complex inputs. XLA_TEST_P(ExhaustiveF32UnaryTest, Acosh) { // Error inherited from Log, which our implementation of Acosh uses. 
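As a reference for the dynamic-shape plumbing above and below, here is a minimal sketch of how the newly added shape helpers (Shape::is_dynamic, ShapeUtil::MakeStaticShape, ShapeUtil::DynamicShapeIsCompatible, ShapeUtil::EqualStructure) are expected to compose. The wrapper function is hypothetical and exists only for exposition; it uses only APIs introduced in this patch or already present in ShapeUtil.

#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

// Hypothetical helper, for exposition only.
void DynamicShapeHelpersSketch() {
  // f32[<=4]: a rank-1 array whose dimension 0 is dynamic with bound 4.
  Shape dynamic = ShapeUtil::MakeShape(F32, {4});
  dynamic.set_dynamic_dimension(0, true);
  CHECK(dynamic.is_dynamic());

  // MakeStaticShape drops the dynamic bits but keeps the bounds: f32[4].
  Shape bounded = ShapeUtil::MakeStaticShape(dynamic);
  CHECK(bounded.is_static());

  // A runtime shape fits a bounded shape iff every dimension is within the
  // corresponding bound.
  CHECK(ShapeUtil::DynamicShapeIsCompatible(ShapeUtil::MakeShape(F32, {2}),
                                            bounded));
  CHECK(!ShapeUtil::DynamicShapeIsCompatible(ShapeUtil::MakeShape(F32, {5}),
                                             bounded));

  // EqualStructure compares only tuple structure, not element types or
  // dimensions.
  CHECK(ShapeUtil::EqualStructure(
      ShapeUtil::MakeTupleShape(
          {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeShape(S32, {})}),
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {7}),
                                 ShapeUtil::MakeShape(F32, {3, 3})})));
}

}  // namespace xla

DynamicShapeIsCompatible is the check the XRT execute path below relies on when it validates a runtime input buffer against its bounded compile-time shape before reallocating it in UpdateDynamicInputs.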
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 5a482305513..d575bbb1f3e 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -863,7 +863,7 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { // Starts = iteration * 2; auto starts = Mul(iteration, ConstantR0(&builder, 2)); // UpdateSlice. - auto out1 = DynamicUpdateSlice(input, update, starts); + auto out1 = DynamicUpdateSlice(input, update, {starts}); Tuple(&builder, {out0, out1}); body = builder.Build().ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 826876ed9cb..f4b08f454b9 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -272,7 +272,15 @@ message DebugOptions { // True if TraceMe annotations are enabled for XLA:CPU. bool xla_cpu_enable_xprof_traceme = 137; - // Next id: 138 + // It is usually preferable to not fallback to the driver; it can consume more + // memory, or have bugs. + bool xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found = 138; + + // It is usually preferable to not fallback to the driver; it can consume more + // memory, or have bugs. + bool xla_gpu_unsafe_fallback_to_driver_on_ptxas_error = 139; + + // Next id: 140 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. @@ -325,6 +333,10 @@ message ExecutionOptions { // Used to identify a set of programs that should be launch together. int32 launch_id = 10; + + // Indicates whether to use SPMD (true) or MPMD (false) partitioning when + // num_partitions > 1 and XLA is requested to partition the input program. + bool use_spmd_partitioning = 11; } message GetDeviceHandlesRequest { diff --git a/tensorflow/compiler/xrt/BUILD b/tensorflow/compiler/xrt/BUILD index d1445144b76..332c8ff9a14 100644 --- a/tensorflow/compiler/xrt/BUILD +++ b/tensorflow/compiler/xrt/BUILD @@ -58,6 +58,7 @@ cc_library( "xrt_state.h", "xrt_util.h", ], + visibility = ["//visibility:public"], deps = [ ":xrt_proto_cc", "//tensorflow/compiler/jit:xla_device", diff --git a/tensorflow/compiler/xrt/kernels/BUILD b/tensorflow/compiler/xrt/kernels/BUILD index 309b4f4c85a..494ba29e981 100644 --- a/tensorflow/compiler/xrt/kernels/BUILD +++ b/tensorflow/compiler/xrt/kernels/BUILD @@ -49,6 +49,7 @@ cc_library( deps = [ ":xrt_state_ops", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -59,6 +60,7 @@ cc_library( "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options", "//tensorflow/compiler/xrt:xrt_compile_ops_op_lib", "//tensorflow/compiler/xrt:xrt_execute_op_op_lib", "//tensorflow/compiler/xrt:xrt_proto_cc", diff --git a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc index 83b1b4c8a05..ba6e6a093d6 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc @@ -51,6 +51,46 @@ namespace tensorflow { namespace { +Status GenerateXlaDeviceAssignment( + const xrt::DeviceAssignment& xrt_device_assignment, int num_replicas, + int num_cores_per_replica, xla::DeviceAssignment* 
device_assignment) { + if (num_cores_per_replica != + xrt_device_assignment.computation_devices_size()) { + return errors::InvalidArgument( + "Device assignment does not have the correct number of " + "computation_devices: num_cores_per_replica=", + num_cores_per_replica, " computation_devices=", + xrt_device_assignment.computation_devices_size()); + } + for (int64 c = 0; c < xrt_device_assignment.computation_devices_size(); ++c) { + const auto& computation_devices = + xrt_device_assignment.computation_devices(c); + if (num_replicas != computation_devices.replica_devices_size()) { + return errors::InvalidArgument( + "Device assignment does not have the correct number of " + "replica_device_ids: num_replicas=", + num_replicas, + " replica_devices=", computation_devices.replica_devices_size()); + } + for (int64 r = 0; r < computation_devices.replica_devices_size(); ++r) { + const auto& coords = computation_devices.replica_devices(r); + if (coords.value_size() != 4) { + return errors::InvalidArgument( + "Device assignment mesh coordinates must have 4 entries, got ", + coords.value_size()); + } + for (int n = 0; n < 3; ++n) { + if (coords.value(n) != 0) { + return errors::InvalidArgument("Mesh coordinate at index ", n, + " must be 0, got ", coords.value(n)); + } + } + (*device_assignment)(r, c) = coords.value(3); + } + } + return Status::OK(); +} + class XRTCompileOp : public OpKernel { public: explicit XRTCompileOp(OpKernelConstruction* ctx); @@ -83,14 +123,13 @@ Status XRTCompileOp::Compile(OpKernelContext* ctx, const xrt::XLAComputation& computation_proto, std::unique_ptr* program) { const xrt::XLAComputationConfig& config = computation_proto.config(); + // Sanity checks for options not yet supported. + int num_cores_per_replica = std::max(config.num_cores_per_replica(), 1); + TF_RET_CHECK(num_cores_per_replica == 1); + TF_RET_CHECK(config.per_core_program_shape_size() == 0); // The default config value is 0; treat it as 1 for convenience. int num_replicas = config.num_replicas() ? config.num_replicas() : 1; - TF_RET_CHECK(num_replicas == 1); - int num_cores_per_replica = - config.num_cores_per_replica() ? config.num_cores_per_replica() : 1; - TF_RET_CHECK(num_cores_per_replica == 1); - TF_RET_CHECK(config.per_core_program_shape_size() == 0); // We are guaranteed that the underlying device object won't be deleted out // from under us, while the ScopedRef is live. 
@@ -119,13 +158,22 @@ Status XRTCompileOp::Compile(OpKernelContext* ctx, argument_layout_ptrs[i] = &argument_layouts[i]; } xla::ExecutableBuildOptions build_options; - build_options.set_device_ordinal(client->default_device_ordinal()); + build_options.set_device_ordinal(device_ref.device_ordinal()); + build_options.set_num_replicas(num_replicas); build_options.set_result_layout(xla::Shape(config.program_shape().result())); build_options.set_device_allocator(device_ref.backend()->memory_allocator()); if (config.has_debug_options()) { *build_options.mutable_debug_options() = BuildXlaDebugOptions(config.debug_options()); } + if (config.has_device_assignment()) { + xla::DeviceAssignment device_assignment(num_replicas, + num_cores_per_replica); + TF_RETURN_IF_ERROR( + GenerateXlaDeviceAssignment(config.device_assignment(), num_replicas, + num_cores_per_replica, &device_assignment)); + build_options.set_device_assignment(device_assignment); + } VLOG(1) << "Building executable"; TF_ASSIGN_OR_RETURN( @@ -158,7 +206,8 @@ void XRTCompileOp::Compute(OpKernelContext* ctx) { OP_REQUIRES_OK(ctx, CompilationCacheKey(computation_proto, &key)); // Process-wide cache of XLA executables. - auto cache_or = GetOrCreateCompilationCache(rm, /*max_number_of_entries=*/0); + auto cache_or = XRTGenericDeviceAccessor::GetOrCreateCompilationCache( + ctx, /*max_number_of_entries=*/0); OP_REQUIRES_OK(ctx, cache_or.status()); auto cache = cache_or.ConsumeValueOrDie(); @@ -211,15 +260,11 @@ void XRTReleaseCompilationRefOp::Compute(OpKernelContext* ctx) { VLOG(1) << "XRTReleaseCompilationRefOp::Compute"; auto timed = monitoring::MakeTimed(xrt_metrics::GetReleaseCompilationCell()); - ResourceMgr* rm; - OP_REQUIRES_OK(ctx, XRTGenericDeviceAccessor::GetResourceManager(ctx, &rm)); - // Process-wide cache of XLA executables. - XRTCompilationCache* cache; - OP_REQUIRES_OK(ctx, rm->Lookup( - rm->default_container(), - kXRTCompilationCacheResourceName, &cache)); - core::ScopedUnref cache_unref(cache); + auto cache_or = XRTGenericDeviceAccessor::GetOrCreateCompilationCache( + ctx, /*max_number_of_entries=*/0); + OP_REQUIRES_OK(ctx, cache_or.status()); + auto cache = cache_or.ConsumeValueOrDie(); const Tensor& keys_tensor = ctx->input(0); auto flat_keys = keys_tensor.flat(); diff --git a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc index 45c8e1ad59a..2fc599e42df 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc @@ -18,7 +18,9 @@ limitations under the License. #include #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/computation_placer.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h" #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -37,7 +39,11 @@ limitations under the License. 
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/monitoring/timed.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/device_memory.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" +#include "tensorflow/stream_executor/platform.h" #include "tensorflow/stream_executor/stream_executor.h" #include "tensorflow/stream_executor/stream_executor_internal.h" @@ -145,31 +151,301 @@ xla::StatusOr GetChainedOpInputs( return std::move(input_buffers); } +// Given a shape, returns a byte array representing the shape metadata of the +// shape. The shape metadata contains dimensions sizes stored as contiguous S32. +std::vector PrepareMetadata(const xla::Shape& shape) { + DCHECK(shape.is_static()); + DCHECK(shape.IsArray()); + // Each dimension size is stored as a S32. + std::vector result(shape.dimensions_size()); + for (int64 i = 0; i < shape.dimensions_size(); ++i) { + result[i] = shape.dimensions(i); + } + return result; +} + +// Given a buffer with dynamic shape, update buffer metadata at the correct +// offset starting from that buffer. +// +// +-----------+ +// |Payload | +// +-----------+ +// | Padding | +// +-----------+ +// |dim_size_0 | (each dim_size is a S32): +// +-----------+ +// |dim_size_1 | +// +-----------+ +// .......... +// +-----------+ +// +// Size of payload = ByteSizeOf(runtime_shape) +// Size of payload + padding = ByteSizeOf(compile_time_shape_static) +// Size of payload + padding + metadata = ByteSizeOf(compile_time_shape) +Status UpdateMetadata(se::Stream* stream, se::DeviceMemory* buffer, + const xla::Shape& compile_time_shape, + const xla::Shape& runtime_shape) { + TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform( + stream->parent()->platform())); + TF_ASSIGN_OR_RETURN( + auto transfer_manager, + xla::TransferManager::GetForPlatform(stream->parent()->platform())); + auto shape_size_fn = compiler->ShapeSizeBytesFunction(); + xla::Shape compile_time_shape_static = + xla::ShapeUtil::MakeStaticShape(compile_time_shape); + uint64 offset = shape_size_fn(compile_time_shape_static); + uint64 metadata_size = shape_size_fn(compile_time_shape) - offset; + auto metadata_buffer = + stream->parent()->GetSubBuffer(buffer, offset, metadata_size); + + auto metadata_literal = std::make_shared( + xla::LiteralUtil::CreateR1(PrepareMetadata(runtime_shape))); + TF_RETURN_IF_ERROR(transfer_manager->TransferArrayToDeviceAsync( + stream, *metadata_literal, metadata_buffer)); + // Retain the literal until the end of the transfer. + stream->ThenDoHostCallback([metadata_literal]() { return Status::OK(); }); + return Status::OK(); +} + +// Given a static input buffer, convert it to dynamic form by expanding it to +// the bounded size and attaching a metadata filled with dimension sizes. +// +// From: +// +--------+ +// |Payload | +// +--------+ +// +// To: +// +// +--------+ +// |Payload | +// +--------+ +// | Padding| +// +--------+ +// |Metadata| +// +--------+ +// +// As we can't expand the size of an existing memory allocation, a reallocation +// is required. A list of new allocations are returned after this function. The +// caller is reponsible for maintaining those allocations. 
+xla::StatusOr> UpdateDynamicInputs( + se::Stream* stream, se::DeviceMemoryAllocator* allocator, + std::vector runtime_inputs, + const std::vector& compile_time_shapes) { + std::vector new_allocations; + TF_RET_CHECK(runtime_inputs.size() == compile_time_shapes.size()); + TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform( + stream->parent()->platform())); + auto shape_size_fn = compiler->ShapeSizeBytesFunction(); + for (int64 i = 0; i < compile_time_shapes.size(); i++) { + const xla::Shape& compile_time_shape = compile_time_shapes[i].shape(); + if (compile_time_shape.is_static()) { + continue; + } + auto* runtime_input = runtime_inputs[i]; + + bool element_modified = false; + TF_RETURN_IF_ERROR(xla::ShapeUtil::ForEachSubshapeWithStatus( + compile_time_shape, + [&](const xla::Shape& compile_time_shape, + const xla::ShapeIndex& index) -> Status { + if (compile_time_shape.IsTuple() || compile_time_shape.is_static()) { + return Status::OK(); + } + const xla::Shape& runtime_shape = xla::ShapeUtil::GetSubshape( + runtime_input->on_device_shape(), index); + TF_RET_CHECK(!runtime_shape.IsTuple()); + TF_RET_CHECK(xla::ShapeUtil::DynamicShapeIsCompatible( + runtime_shape, compile_time_shape)); + se::DeviceMemoryBase* static_input = + runtime_input->buffers().mutable_element(index); + TF_ASSIGN_OR_RETURN( + auto dynamic_input, + allocator->Allocate(stream->parent()->device_ordinal(), + shape_size_fn(compile_time_shape))); + new_allocations.emplace_back(std::move(dynamic_input)); + se::DeviceMemory* dynamic_input_base = + new_allocations.back().ptr(); + // Send the original data to the new location. + stream->ThenMemcpyD2D(dynamic_input_base, *static_input, + static_input->size()); + TF_RETURN_IF_ERROR(UpdateMetadata(stream, dynamic_input_base, + compile_time_shape, runtime_shape)); + // Modify the memory location in the input shape tree to point to the + // new input. + runtime_input->set_buffer(*dynamic_input_base, index); + element_modified = true; + return Status::OK(); + })); + if (element_modified) { + runtime_input->set_shapes(compile_time_shape, compile_time_shape); + // The input location has been modified, need to fix tuple table to + // point to the correct address. + TF_ASSIGN_OR_RETURN( + auto transfer_manager, + xla::TransferManager::GetForPlatform(stream->parent()->platform())); + TF_RETURN_IF_ERROR( + transfer_manager->WriteTupleIndexTablesAsync(stream, *runtime_input)); + } + } + return std::move(new_allocations); +} + +xla::StatusOr ReadMetadataLiteral( + se::Stream* stream, se::DeviceMemoryBase* buffer, + const xla::Shape& buffer_shape, xla::TransferManager* transfer_manager) { + TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform( + stream->parent()->platform())); + auto shape_size_fn = compiler->ShapeSizeBytesFunction(); + xla::Shape buffer_shape_static = + xla::ShapeUtil::MakeStaticShape(buffer_shape); + const int64 offset = shape_size_fn(buffer_shape_static); + int64 metadata_size = shape_size_fn(buffer_shape) - offset; + TF_RET_CHECK(metadata_size != 0); + auto buffer_8 = se::DeviceMemory(*buffer); + auto metadata_buffer = + stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size); + return transfer_manager->TransferArrayFromDevice( + stream, + xla::ShapeUtil::MakeShape(xla::S32, {buffer_shape.dimensions_size()}), + metadata_buffer); +} + +// For each subshape in the result buffer that's dynamic, read the dynamic +// dimension sizes from the metadata, and update output shapes. The result shape +// is a static and concrete shape. 
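+// For example, if a leaf of the on-device output shape is f32[<=4] and the
+// metadata read back from its buffer is {3}, both the host and device shapes
+// are rewritten to f32[3] and the dynamic-dimension flags are cleared.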
+xla::Status UpdateDynamicOutputs(se::Stream* stream, + xla::ShapedBuffer* shaped_buffer, + xla::Shape* output_host_shape, + xla::Shape* output_device_shape) { + DCHECK(output_device_shape->is_dynamic()); + TF_ASSIGN_OR_RETURN( + auto transfer_manager, + xla::TransferManager::GetForPlatform(stream->parent()->platform())); + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + TF_RETURN_IF_ERROR(shaped_buffer->buffers().ForEachMutableElementWithStatus( + [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) { + const xla::Shape& buffer_shape = + xla::ShapeUtil::GetSubshape(*output_device_shape, index); + if (buffer_shape.IsTuple()) { + return Status::OK(); + } + xla::Shape& host_shape = + *xla::ShapeUtil::GetMutableSubshape(output_host_shape, index); + xla::Shape& device_shape = + *xla::ShapeUtil::GetMutableSubshape(output_device_shape, index); + if (device_shape.is_static()) { + return Status::OK(); + } + TF_ASSIGN_OR_RETURN(auto metadata, + ReadMetadataLiteral(stream, buffer, buffer_shape, + transfer_manager)); + // Update shape size from metadata. + for (int64 i = 0; i < metadata.element_count(); ++i) { + host_shape.mutable_dimensions()[i] = metadata.Get({i}); + device_shape.mutable_dimensions()[i] = metadata.Get({i}); + } + return Status::OK(); + })); + output_host_shape->clear_dynamic_dimensions(); + output_device_shape->clear_dynamic_dimensions(); + return Status::OK(); +} + +// Create output tuple from run_result. +xla::StatusOr> CreateOutputTuple( + se::Stream* stream, xla::ScopedShapedBuffer run_result, + xla::Backend* backend, int device_ordinal) { + XRTTupleAllocation* output_tuple; + xla::ShapedBuffer shaped_buffer = run_result.release(); + if (shaped_buffer.on_device_shape().is_dynamic()) { + // Update dynamic shapes from output buffer, and create a XRT tensor with + // dimension sizes read from metadata. + xla::Shape output_host_shape = shaped_buffer.on_host_shape(); + xla::Shape output_device_shape = shaped_buffer.on_device_shape(); + TF_RETURN_IF_ERROR(UpdateDynamicOutputs( + stream, &shaped_buffer, &output_host_shape, &output_device_shape)); + TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer( + shaped_buffer, output_host_shape, output_device_shape, backend, + device_ordinal, &output_tuple)); + } else { + // Fast-path: Don't copy shapes of output buffer. 
+ TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer( + shaped_buffer, backend, device_ordinal, &output_tuple)); + } + return RefPtr(output_tuple); +} + xla::StatusOr> RunExecutable( OpKernelContext* context, XRTGenericDeviceAccessor::ScopedRef* device_ref, xla::LocalExecutable* executable, const InputBuffers& input_buffers, - se::Stream* stream, int rng_seed) { + se::Stream* stream, int rng_seed, + const xrt::CommonExecutionConfig& config) { VLOG(2) << "Executing computation."; xla::ExecutableRunOptions run_options; run_options.set_stream(stream); run_options.set_allocator(device_ref->backend()->memory_allocator()); run_options.set_intra_op_thread_pool(&context->eigen_cpu_device()); run_options.set_rng_seed(rng_seed); + if (config.run_id() != 0) { + run_options.set_run_id(xla::RunId(config.run_id())); + } + if (executable->executable() + ->module_config() + .has_static_device_assignment()) { + run_options.set_device_assignment( + &executable->executable()->module_config().static_device_assignment()); + } + xla::GpuExecutableRunOptions gpu_options; + std::vector gpu_global_ids; + if (config.local_replica_mapping_size() > 0) { + gpu_global_ids.reserve(config.local_replica_mapping_size()); + for (auto& gid : config.local_replica_mapping()) { + gpu_global_ids.emplace_back(xla::GlobalDeviceId(gid)); + } + gpu_options.set_gpu_global_device_ids(gpu_global_ids); + } + std::shared_ptr nccl_factory = GetNcclUniqueIdFactory(); + if (nccl_factory != nullptr) { + auto uid_callback = + [&](const xla::NcclCliqueKey& key) -> xla::StatusOr { + std::vector replicas; + for (auto& device : key.devices()) { + replicas.push_back(device.value()); + } + return nccl_factory->GetUniqueId(replicas); + }; + gpu_options.set_nccl_unique_id_callback(uid_callback); + } + run_options.set_gpu_executable_run_options(&gpu_options); Env* env = Env::Default(); auto start_time = env->NowMicros(); + const std::vector& shape_layouts = + executable->executable() + ->module_config() + .entry_computation_layout() + .parameter_layouts(); + TF_ASSIGN_OR_RETURN(auto new_allocations, + UpdateDynamicInputs(stream, run_options.allocator(), + input_buffers.input_pointers, + shape_layouts)); + auto new_allocations_ptr = + std::make_shared>( + std::move(new_allocations)); TF_ASSIGN_OR_RETURN( xla::ScopedShapedBuffer run_result, executable->Run(input_buffers.input_pointers, run_options)); + // Retain the new allocation for input memory until the end of execution. + stream->ThenDoHostCallback([new_allocations_ptr]() { return Status::OK(); }); + auto elapsed = env->NowMicros() - start_time; VLOG(2) << "Elapsed time: " << elapsed << "us"; - auto shaped_buffer = run_result.release(); - XRTTupleAllocation* output_tuple; - TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer( - shaped_buffer, device_ref->backend(), device_ref->device_ordinal(), - &output_tuple)); - RefPtr output_tuple_ptr(output_tuple); + TF_ASSIGN_OR_RETURN( + RefPtr output_tuple_ptr, + CreateOutputTuple(stream, std::move(run_result), device_ref->backend(), + device_ref->device_ordinal())); // The ScopedShapedBuffer returned by the executable Run() API, in case of // input/output buffer aliasing, might have holes in it, which need to be @@ -182,7 +458,7 @@ xla::StatusOr> RunExecutable( const xla::HloInputOutputAliasConfig::Alias& alias) -> Status { TF_RET_CHECK(alias.parameter_number < input_buffers.input_tuples.size()); return alias.kind == xla::HloInputOutputAliasConfig::AliasKind::kUserAlias - ? output_tuple->AliasBufferFrom( + ? 
output_tuple_ptr->AliasBufferFrom( *input_buffers.input_tuples[alias.parameter_number], alias.parameter_index, output_index) : Status::OK(); @@ -196,10 +472,11 @@ xla::StatusOr> ExecuteComputation( OpKernelContext* context, XRTMemoryManager* memory_manager, XRTGenericDeviceAccessor::ScopedRef* device_ref, xla::LocalExecutable* executable, const InputBuffers& input_buffers, - se::Stream* stream, int rng_seed) { + se::Stream* stream, int rng_seed, + const xrt::CommonExecutionConfig& config) { auto runfn = [&]() { return RunExecutable(context, device_ref, executable, input_buffers, stream, - rng_seed); + rng_seed, config); }; // We pass zero as requested_free_size as there is no simple way to get the @@ -215,13 +492,15 @@ xla::StatusOr> ExecuteComputation( XRTGenericDeviceAccessor::ScopedRef* device_ref, xla::LocalExecutable* executable, const std::vector& input_coords, bool release_inputs, - se::Stream* stream, int rng_seed) { + se::Stream* stream, int rng_seed, + const xrt::CommonExecutionConfig& config) { XRTMemoryManager::WorkingSet working_set(memory_manager); TF_ASSIGN_OR_RETURN(InputBuffers input_buffers, GetInputBuffers(&working_set, device_ref->backend(), input_coords, release_inputs)); return ExecuteComputation(context, memory_manager.get(), device_ref, - executable, input_buffers, stream, rng_seed); + executable, input_buffers, stream, rng_seed, + config); } // XRTExecuteOp @@ -270,8 +549,9 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) { bool release_inputs = config_proto.release_input_handles(); bool release_compilation = config_proto.release_compilation_handle(); - TF_ASSIGN_OR_RETURN( - auto cache, GetOrCreateCompilationCache(rm, /*max_number_of_entries=*/0)); + TF_ASSIGN_OR_RETURN(auto cache, + XRTGenericDeviceAccessor::GetOrCreateCompilationCache( + context, /*max_number_of_entries=*/0)); // We are guaranteed that the underlying device object won't be deleted out // from under us, while the ScopedRef is live. class XRTGenericDeviceAccessor::ScopedRef device_ref; @@ -302,7 +582,8 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) { TF_ASSIGN_OR_RETURN( RefPtr output_tuple, ExecuteComputation(context, memory_manager, &device_ref, executable, - input_coords, release_inputs, stream, rng_seed)); + input_coords, release_inputs, stream, rng_seed, + config_proto.common_config())); return CreateExecuteOutput(context, memory_manager.get(), std::move(output_tuple), @@ -351,8 +632,9 @@ Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) { xrt::XRTChainedExecuteConfig config; TF_RET_CHECK(ParseFromTString(execution_config.scalar()(), &config)); - TF_ASSIGN_OR_RETURN( - auto cache, GetOrCreateCompilationCache(rm, /*max_number_of_entries=*/0)); + TF_ASSIGN_OR_RETURN(auto cache, + XRTGenericDeviceAccessor::GetOrCreateCompilationCache( + context, /*max_number_of_entries=*/0)); // We are guaranteed that the underlying device object won't be deleted out // from under us, while the ScopedRef is live. 
class XRTGenericDeviceAccessor::ScopedRef device_ref; @@ -379,7 +661,8 @@ Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) { xla::LocalExecutable* executable = entry->get().get_executable(); return ExecuteComputation(context, memory_manager.get(), &device_ref, - executable, input_buffers, stream, rng_seed); + executable, input_buffers, stream, rng_seed, + config.common_config()); }; return ExecuteChained(context, memory_manager, device_ref.backend(), diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc index 243289c8821..fbf9dfd0a17 100644 --- a/tensorflow/compiler/xrt/tests/raw_api_test.cc +++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc @@ -49,6 +49,67 @@ limitations under the License. namespace tensorflow { namespace { +xla::XlaComputation ReturnDynamicR1() { + xla::XlaBuilder builder("ReturnDynamicR1"); + auto p0 = xla::Parameter(&builder, 0, + xla::ShapeUtil::MakeShape(xla::F32, {4}), "P0"); + auto p1 = xla::Parameter(&builder, 1, + xla::ShapeUtil::MakeShape(xla::F32, {4}), "P1"); + auto p2 = xla::Parameter(&builder, 2, xla::ShapeUtil::MakeShape(xla::S32, {}), + "P2"); + auto sum = xla::Add(p0, p1); + auto pad_sum = xla::SetDimensionSize(sum, p2, 0); + return builder.Build(pad_sum).ValueOrDie(); +} + +xla::XlaComputation AcceptDynamicR1() { + xla::XlaBuilder builder("AcceptDynamicR1"); + xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_shape.set_dynamic_dimension(0, true); + auto p0 = xla::Parameter(&builder, 0, dyn_shape, "P0"); + auto p1 = xla::Parameter(&builder, 1, dyn_shape, "P1"); + auto sum = xla::Add(p0, p1); + return builder.Build(sum).ValueOrDie(); +} + +xla::XlaComputation ReturnDynamicR1Tuple() { + xla::XlaBuilder builder("ReturnDynamicR1Tuple"); + auto p0 = xla::Parameter(&builder, 0, + xla::ShapeUtil::MakeShape(xla::F32, {4}), "P0"); + auto p1 = xla::Parameter(&builder, 1, + xla::ShapeUtil::MakeShape(xla::F32, {4}), "P1"); + auto p2 = xla::Parameter(&builder, 2, xla::ShapeUtil::MakeShape(xla::S32, {}), + "P2"); + auto sum = xla::Add(p0, p1); + auto sub = xla::Sub(p0, p1); + auto one = xla::One(&builder, xla::S32); + auto pad_sum = xla::SetDimensionSize(sum, p2, 0); + auto pad_sub = xla::SetDimensionSize(sub, p2 + one, 0); + auto tuple = xla::Tuple(&builder, {pad_sum, sum, pad_sub}); + return builder.Build(tuple, /*remove_dynamic_dimensions=*/true).ValueOrDie(); +} + +xla::XlaComputation AcceptDynamicR1Tuple() { + xla::XlaBuilder builder("AcceptDynamicR1"); + xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_shape.set_dynamic_dimension(0, true); + xla::Shape tuple_shape = + xla::ShapeUtil::MakeTupleShape({dyn_shape, dyn_shape}); + xla::Shape nest_tuple_shape = + xla::ShapeUtil::MakeTupleShape({dyn_shape, dyn_shape}); + auto p = xla::Parameter(&builder, 0, tuple_shape, "P0"); + auto p0 = xla::GetTupleElement(p, 0); + auto p1 = xla::GetTupleElement(p, 1); + auto sum = xla::Add(p0, p1); + return builder.Build(sum).ValueOrDie(); +} + +template +xla::LiteralProto CreateR0(T v) { + auto array = xla::LiteralUtil::CreateR0(v); + return array.ToProto(); +} + class XrtClientSession : public ClientSession { public: explicit XrtClientSession(const Scope& scope) : ClientSession(scope) { @@ -61,6 +122,11 @@ class XrtClientSession : public ClientSession { string* xla_test_device_ptr; // initial value set in main() string* xla_platform_ptr; // initial value set in main() +bool SupportDynamicShapes() { + // TODO(jackcao): Support dynamic shapes on XLA GPU. 
+ return *xla_test_device_ptr != "XLA_GPU"; +} + string DeviceFromFlag() { string xla_test_device = *xla_test_device_ptr; return absl::StrCat("/device:", xla_test_device, ":0"); @@ -1035,6 +1101,239 @@ TEST(RawApiTest, CompileAndExecute) { EXPECT_EQ(program_shape.parameters_size(), 2); } +TEST(RawApiTest, DynamicR1Test) { + if (!SupportDynamicShapes()) { + return; + } + xrt::XLAAllocation p0; + *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f}); + xrt::XLAAllocation p1; + *p1.mutable_value() = FloatVector({1.0f, -1.0f, 2.5f, 1.17f}); + xrt::XLAAllocation p2; + *p2.mutable_value() = CreateR0(2); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto(); + *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto(); + xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_shape.set_dynamic_dimension(0, true); + *shapes->mutable_result() = dyn_shape.ToProto(); + StoreComputationSnapshot(ReturnDynamicR1(), c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + Scope cpu_root = root.WithDevice("/device:CPU:0"); + auto e_config = ops::Const(cpu_root, e.SerializeAsString()); + auto computation = ops::Const(cpu_root, c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto p0_value = ops::Const(cpu_root, p0.SerializeAsString()); + auto p0_handle = ops::XRTAllocate(root, p0_value); + auto p1_value = ops::Const(cpu_root, p1.SerializeAsString()); + auto p1_handle = ops::XRTAllocate(root, p1_value); + auto p2_value = ops::Const(cpu_root, p2.SerializeAsString()); + auto p2_handle = ops::XRTAllocate(root, p2_value); + auto result = ops::XRTExecute( + root, c_handle.handle, e_config, + {Output(p0_handle), Output(p1_handle), Output(p2_handle)}); + auto read_back = ops::XRTReadLiteralAndRelease(root, result); + TF_ASSERT_OK(root.status()); + + XrtClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + auto expected = xla::LiteralUtil::CreateR1({2.0f, 1.0f}); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); +} + +TEST(RawApiTest, DynamicR1TupleTest) { + if (!SupportDynamicShapes()) { + return; + } + xrt::XLAAllocation p0; + *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f}); + xrt::XLAAllocation p1; + *p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f, 1.0f}); + xrt::XLAAllocation p2; + *p2.mutable_value() = CreateR0(2); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto(); + *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto(); + xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_shape.set_dynamic_dimension(0, true); + *shapes->mutable_result() = + xla::ShapeUtil::MakeTupleShape( + {dyn_shape, xla::ShapeUtil::MakeShape(xla::F32, {4}), dyn_shape}) + .ToProto(); + 
StoreComputationSnapshot(ReturnDynamicR1Tuple(), c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + Scope cpu_root = root.WithDevice("/device:CPU:0"); + auto e_config = ops::Const(cpu_root, e.SerializeAsString()); + auto computation = ops::Const(cpu_root, c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto p0_value = ops::Const(cpu_root, p0.SerializeAsString()); + auto p0_handle = ops::XRTAllocate(root, p0_value); + auto p1_value = ops::Const(cpu_root, p1.SerializeAsString()); + auto p1_handle = ops::XRTAllocate(root, p1_value); + auto p2_value = ops::Const(cpu_root, p2.SerializeAsString()); + auto p2_handle = ops::XRTAllocate(root, p2_value); + auto result = ops::XRTExecute( + root, c_handle.handle, e_config, + {Output(p0_handle), Output(p1_handle), Output(p2_handle)}); + auto read_back = ops::XRTReadLiteralAndRelease(root, result); + TF_ASSERT_OK(root.status()); + + XrtClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + + auto expected0 = xla::LiteralUtil::CreateR1({2.0f, 1.0f}); + auto expected1 = xla::LiteralUtil::CreateR1({2.0f, 1.0f, 0.0f, 0.0f}); + auto expected2 = xla::LiteralUtil::CreateR1({0.0f, 3.0f, 1.0f}); + auto expected = + xla::LiteralUtil::MakeTuple({&expected0, &expected1, &expected2}); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); +} + +TEST(RawApiTest, AcceptDynamicR1TupleTest) { + if (!SupportDynamicShapes()) { + return; + } + xrt::XLAAllocation p0; + *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f}); + xrt::XLAAllocation p1; + *p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f}); + + xrt::XLATupleNode tuple_desc; + auto subdesc_10 = tuple_desc.add_tuples(); + auto subdesc_11 = tuple_desc.add_tuples(); + subdesc_10->set_input_index(0); + subdesc_10->set_release_input_handle(true); + subdesc_11->set_input_index(1); + subdesc_11->set_release_input_handle(true); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + xla::Shape dyn_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_input_shape.set_dynamic_dimension(0, true); + xla::Shape dyn_tuple_shape = + xla::ShapeUtil::MakeTupleShape({dyn_input_shape, dyn_input_shape}); + *shapes->add_parameters() = dyn_tuple_shape.ToProto(); + xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_shape.set_dynamic_dimension(0, true); + *shapes->mutable_result() = dyn_shape.ToProto(); + StoreComputationSnapshot(AcceptDynamicR1Tuple(), c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + Scope cpu_root = root.WithDevice("/device:CPU:0"); + auto e_config = ops::Const(cpu_root, e.SerializeAsString()); + auto computation = ops::Const(cpu_root, c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto p0_value = ops::Const(cpu_root, p0.SerializeAsString()); + auto p0_handle = ops::XRTAllocate(root, p0_value); + auto p1_value = ops::Const(cpu_root, p1.SerializeAsString()); + auto p1_handle = ops::XRTAllocate(root, p1_value); + + auto tuple_0 = ops::Const(root.WithDevice("/device:CPU:0"), + 
tuple_desc.SerializeAsString()); + auto t0_handle = ops::XRTMakeTuple( + root, tuple_0, + {static_cast(p0_handle), static_cast(p1_handle)}); + auto result = ops::XRTExecute(root, c_handle.handle, e_config, + {static_cast(t0_handle)}); + auto read_back = ops::XRTReadLiteralAndRelease(root, result); + TF_ASSERT_OK(root.status()); + + XrtClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + + auto expected = xla::LiteralUtil::CreateR1({2.0f, 1.0f, 0.0f}); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); +} + +TEST(RawApiTest, AcceptDynamicR1Test) { + if (!SupportDynamicShapes()) { + return; + } + xrt::XLAAllocation p0; + *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f}); + xrt::XLAAllocation p1; + *p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f}); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + xla::Shape dyn_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_input_shape.set_dynamic_dimension(0, true); + *shapes->add_parameters() = dyn_input_shape.ToProto(); + *shapes->add_parameters() = dyn_input_shape.ToProto(); + xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_shape.set_dynamic_dimension(0, true); + *shapes->mutable_result() = dyn_shape.ToProto(); + StoreComputationSnapshot(AcceptDynamicR1(), c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + Scope cpu_root = root.WithDevice("/device:CPU:0"); + auto e_config = ops::Const(cpu_root, e.SerializeAsString()); + auto computation = ops::Const(cpu_root, c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto p0_value = ops::Const(cpu_root, p0.SerializeAsString()); + auto allocate_op_0 = ops::XRTAllocate(root, p0_value); + auto p1_value = ops::Const(cpu_root, p1.SerializeAsString()); + auto allocate_op_1 = ops::XRTAllocate(root, p1_value); + auto result = ops::XRTExecute(root, c_handle.handle, e_config, + {Output(allocate_op_0), Output(allocate_op_1)}); + auto read_back = ops::XRTReadLiteralAndRelease(root, result); + TF_ASSERT_OK(root.status()); + + XrtClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + + auto expected = xla::LiteralUtil::CreateR1({2.0f, 1.0f, 0.0f}); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); +} + TEST(RawApiTest, CompileAndExecuteWithArgumentVector) { xrt::XLAAllocation p0; *p0.mutable_value() = FloatVector({1.0f, 2.0f}); diff --git a/tensorflow/compiler/xrt/xrt.proto b/tensorflow/compiler/xrt/xrt.proto index 47b7cda2760..9a351732c4b 100644 --- a/tensorflow/compiler/xrt/xrt.proto +++ b/tensorflow/compiler/xrt/xrt.proto @@ -111,6 +111,17 @@ message XLATupleNode { repeated XLATupleNode tuples = 3; } +message CommonExecutionConfig { + // The replica index this execute is driving. + int32 replica_id = 1; + // Mapping local device ordinals to global replica IDs. 
+ // local_replica_mapping[LOCAL_DEVICE_ORDINAL] = GLOBAL_REPLICA_ID + repeated int32 local_replica_mapping = 2; + // The execution run ID used to correlate different XRT execute operations + // happening in parallel from different threads. + int64 run_id = 3; +} + // Options for an XLA execution. message XRTExecutionConfig { // Local device to run on. This is present because the execute Op @@ -133,6 +144,9 @@ message XRTExecutionConfig { // a single tuple allocation the execution will return a vector of // allocations, one for each of the first-level elements of the result tuple. bool return_exploded_tuple = 7; + reserved 8; + // The common configuration for XRT execute operations. + CommonExecutionConfig common_config = 9; } message XRTChainedExecuteConfig { @@ -143,6 +157,9 @@ message XRTChainedExecuteConfig { // Optional key to disambiguate between executions. This is only needed if // multiple host send/recvs may be outstanding concurrently with executions. string execution_instance_key = 3; + reserved 4; + // The common configuration for XRT execute operations. + CommonExecutionConfig common_config = 5; } // A single chained execute operation. An operation can either be a device data diff --git a/tensorflow/compiler/xrt/xrt_device.cc b/tensorflow/compiler/xrt/xrt_device.cc index 1b5557d556d..46954572c5d 100644 --- a/tensorflow/compiler/xrt/xrt_device.cc +++ b/tensorflow/compiler/xrt/xrt_device.cc @@ -17,19 +17,56 @@ limitations under the License. #include "tensorflow/compiler/xrt/xrt_device.h" +#include + #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/mutex.h" namespace tensorflow { +namespace { + +class ResourceMgrArena { + public: + static ResourceMgrArena* Get() { + static ResourceMgrArena* arena = new ResourceMgrArena(); + return arena; + } + + ResourceMgr* GetResourceMgr(const std::string& platform_name) { + mutex_lock lock(mutex_); + auto it = resource_managers_.find(platform_name); + if (it == resource_managers_.end()) { + it = resource_managers_.emplace(platform_name, new ResourceMgr()).first; + } + return it->second; + } + + private: + mutex mutex_; + std::map<std::string, ResourceMgr*> resource_managers_; +}; + +} // namespace /*static*/ Status XRTGenericDeviceAccessor::GetResourceManager( OpKernelContext* ctx, ResourceMgr** rm) { - *rm = ctx->resource_manager(); + const XlaDevice::Metadata* metadata; + TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(ctx, &metadata)); + *rm = ResourceMgrArena::Get()->GetResourceMgr(metadata->platform()->Name()); return Status::OK(); } +/* static */ xla::StatusOr<RefPtr<XRTCompilationCache>> +XRTGenericDeviceAccessor::GetOrCreateCompilationCache( + OpKernelContext* ctx, int64 max_number_of_entries) { + ResourceMgr* rm; + TF_RETURN_IF_ERROR(GetResourceManager(ctx, &rm)); + return tensorflow::GetOrCreateCompilationCache(rm, max_number_of_entries); +} + /*static*/ Status XRTGenericDeviceAccessor::InitScopedRef( OpKernelContext* ctx, int device_ordinal, ScopedRef* scoped_ref) { const XlaDevice::Metadata* metadata; diff --git a/tensorflow/compiler/xrt/xrt_device.h b/tensorflow/compiler/xrt/xrt_device.h index 5ebee7641f0..02fab315830 100644 --- a/tensorflow/compiler/xrt/xrt_device.h +++ b/tensorflow/compiler/xrt/xrt_device.h @@ -19,6 +19,7 @@ limitations under the License. 
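The uid_callback installed in RunExecutable above resolves NCCL clique keys through the NcclUniqueIdFactory hook declared in xrt_util.h below (SetNcclUniqueIdFactory / GetNcclUniqueIdFactory). A single-process factory could look like the following sketch; it assumes GetUniqueId takes absl::Span<const int64> (the template arguments are not visible in this copy of the diff), that an nccl.h header is available, and that generating the id locally is acceptable, whereas a real multi-host deployment would generate one id per clique and distribute it out of band. The class name and structure are illustrative, not part of the patch:

```cpp
#include <map>
#include <memory>
#include <string>

#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "nccl.h"  // include path is an assumption; GPU builds only
#include "tensorflow/compiler/xrt/xrt_util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"

// Sketch: one locally generated NCCL unique id per clique, keyed by the
// participating replica ids. Multi-host jobs would instead broadcast the id
// generated by a single process to all the others.
class LocalNcclUniqueIdFactory : public tensorflow::NcclUniqueIdFactory {
 public:
  std::string GetUniqueId(absl::Span<const tensorflow::int64> replicas) override {
    const std::string key = absl::StrJoin(replicas, ",");
    tensorflow::mutex_lock lock(mu_);
    auto it = ids_.find(key);
    if (it == ids_.end()) {
      ncclUniqueId id;
      ncclGetUniqueId(&id);  // opaque NCCL_UNIQUE_ID_BYTES-sized blob
      it = ids_.emplace(key, std::string(id.internal, NCCL_UNIQUE_ID_BYTES)).first;
    }
    return it->second;
  }

 private:
  tensorflow::mutex mu_;
  std::map<std::string, std::string> ids_;
};

// Installed once at startup, before any XRTExecute op runs:
//   tensorflow::SetNcclUniqueIdFactory(std::make_shared<LocalNcclUniqueIdFactory>());
```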
#define TENSORFLOW_COMPILER_XRT_XRT_DEVICE_H_ #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xrt/xrt_compilation_cache.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" @@ -31,6 +32,9 @@ class XRTGenericDeviceAccessor { public: static Status GetResourceManager(OpKernelContext* ctx, ResourceMgr** rm); + static xla::StatusOr<RefPtr<XRTCompilationCache>> GetOrCreateCompilationCache( + OpKernelContext* ctx, int64 max_number_of_entries); + // We use a ScopedRef pattern here even though it's not strictly necessary, // just so that templated uses of this and the TPU accessor class will be as // similar as possible. diff --git a/tensorflow/compiler/xrt/xrt_util.cc b/tensorflow/compiler/xrt/xrt_util.cc index 4d19d4b1226..b8a0afc92c5 100644 --- a/tensorflow/compiler/xrt/xrt_util.cc +++ b/tensorflow/compiler/xrt/xrt_util.cc @@ -21,10 +21,14 @@ limitations under the License. #include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" namespace tensorflow { namespace { +mutex nccl_factory_mutex(LINKER_INITIALIZED); +std::shared_ptr<NcclUniqueIdFactory>* nccl_factory; + // The ScopedHandles data structure is used in the ExecuteChained() API and its // task is to track tuple allocation registrations. It is used both to track // intermediate results of a chained computation, or its final results. Anything @@ -162,6 +166,19 @@ Status PopulateOpWorkingSet(xla::Backend* backend, } // namespace +void SetNcclUniqueIdFactory(std::shared_ptr<NcclUniqueIdFactory> factory) { + mutex_lock lock(nccl_factory_mutex); + if (nccl_factory == nullptr) { + nccl_factory = new std::shared_ptr<NcclUniqueIdFactory>(); + } + *nccl_factory = std::move(factory); +} + +std::shared_ptr<NcclUniqueIdFactory> GetNcclUniqueIdFactory() { + mutex_lock lock(nccl_factory_mutex); + return nccl_factory != nullptr ? *nccl_factory : nullptr; +} + xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options) { static const bool options_passthrough = DebugOptionsPassThroughEnabled(); if (options_passthrough) { diff --git a/tensorflow/compiler/xrt/xrt_util.h b/tensorflow/compiler/xrt/xrt_util.h index 32244a63081..cc1480fdb00 100644 --- a/tensorflow/compiler/xrt/xrt_util.h +++ b/tensorflow/compiler/xrt/xrt_util.h @@ -18,6 +18,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_ #define TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_ +#include +#include +#include + #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -31,6 +35,19 @@ limitations under the License. namespace tensorflow { +// Factory class which creates NCCL unique IDs based on the replicas +// participating in a given communication. This is only used for GPU backends. +struct NcclUniqueIdFactory { + virtual ~NcclUniqueIdFactory() {} + + // Generates the NCCL unique ID for the given set of replica IDs. 
+ virtual std::string GetUniqueId(absl::Span replicas) = 0; +}; + +void SetNcclUniqueIdFactory(std::shared_ptr factory); + +std::shared_ptr GetNcclUniqueIdFactory(); + struct InputCoords { explicit InputCoords(int64 handle) : handle(handle) {} InputCoords(int64 handle, xla::ShapeIndex index) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 19097fb8922..6b4874a8393 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -83,7 +83,6 @@ load( "tf_gen_op_libs", "tf_genrule_cmd_append_to_srcs", "tf_opts_nortti_if_lite_protos", - "tf_opts_nortti_if_mobile", "tf_portable_full_lite_protos", "transitive_hdrs", ) @@ -100,28 +99,23 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu") # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_cc_tests_gpu") -# buildifier: disable=same-origin-load -# Placeholder: load("//tensorflow:tensorflow.bzl", "tf_portable_proto_lib") - # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_monitoring_deps") # For platform specific build config load( "//tensorflow/core/platform:build_config.bzl", - "tf_additional_all_protos", "tf_additional_lib_deps", "tf_additional_test_deps", "tf_jspb_proto_library", "tf_kernel_tests_linkstatic", "tf_lib_proto_parsing_deps", "tf_portable_deps_no_runtime", + "tf_portable_proto_lib", "tf_proto_library", - "tf_proto_library_cc", "tf_protos_all_impl", "tf_protos_grappler_impl", "tf_protos_profiler_impl", - "tf_pyclif_proto_library", ) load( "//tensorflow/core/platform:rules_cc.bzl", @@ -184,18 +178,18 @@ package_group(name = "friends") # filegroup; e.g. ones with individual proto_library targets. # LINT.IfChange COMMON_PROTO_SRCS = [ - "protobuf/bfc_memory_map.proto", - "protobuf/config.proto", - "protobuf/cluster.proto", - "protobuf/debug.proto", - "protobuf/device_filters.proto", - "protobuf/device_properties.proto", - "protobuf/graph_debug_info.proto", - "protobuf/queue_runner.proto", - "protobuf/rewriter_config.proto", - "protobuf/tensor_bundle.proto", - "protobuf/saver.proto", - "protobuf/verifier_config.proto", + "//tensorflow/core/protobuf:bfc_memory_map.proto", + "//tensorflow/core/protobuf:config.proto", + "//tensorflow/core/protobuf:cluster.proto", + "//tensorflow/core/protobuf:debug.proto", + "//tensorflow/core/protobuf:device_filters.proto", + "//tensorflow/core/protobuf:device_properties.proto", + "//tensorflow/core/protobuf:graph_debug_info.proto", + "//tensorflow/core/protobuf:queue_runner.proto", + "//tensorflow/core/protobuf:rewriter_config.proto", + "//tensorflow/core/protobuf:tensor_bundle.proto", + "//tensorflow/core/protobuf:saver.proto", + "//tensorflow/core/protobuf:verifier_config.proto", ] EXAMPLE_PROTO_SRCS = [ @@ -242,7 +236,7 @@ PROFILER_PROTO_SRCS = [ ] ERROR_CODES_PROTO_SRCS = [ - "protobuf/error_codes.proto", + "//tensorflow/core/protobuf:error_codes.proto", "//tensorflow/core/lib/core:error_codes.proto", ] # LINT.ThenChange(//tensorflow/core/portable_proto_config.asciipb) @@ -255,11 +249,13 @@ tf_proto_library( cc_api_version = 2, make_default_target_header_only = True, protodeps = [ - ":core_protos", - ":error_codes_proto_impl", "//tensorflow/core/example:protos_all", "//tensorflow/core/framework:protos_all", "//tensorflow/core/lib/core:error_codes_proto", + "//tensorflow/core/profiler/protobuf:xplane_proto", + "//tensorflow/core/profiler:profiler_options_proto", + "//tensorflow/core/protobuf:error_codes_proto_impl", + "//tensorflow/core/protobuf:for_core_protos", "//tensorflow/core/util:protos_all", 
"//tensorflow/core/util:test_log_proto_impl", ], @@ -619,6 +615,7 @@ tf_gen_op_libs( "clustering_ops", "collective_ops", "control_flow_ops", + "count_ops", "ctc_ops", "data_flow_ops", "dataset_ops", @@ -847,6 +844,7 @@ cc_library( ":clustering_ops_op_lib", ":collective_ops_op_lib", ":control_flow_ops_op_lib", + ":count_ops_op_lib", ":ctc_ops_op_lib", ":cudnn_rnn_ops_op_lib", ":data_flow_ops_op_lib", @@ -889,23 +887,29 @@ cc_library( ":state_ops_op_lib", ":stateless_random_ops_op_lib", ":string_ops_op_lib", - ":tpu_configuration_ops_op_lib", - ":tpu_cross_replica_ops_op_lib", - ":tpu_embedding_ops_op_lib", - ":tpu_embedding_load_retrieve_ops_op_lib", - ":tpu_functional_ops_op_lib", - ":tpu_heartbeat_ops_op_lib", - ":tpu_host_compute_ops_op_lib", - ":tpu_infeed_ops_op_lib", - ":tpu_outfeed_ops_op_lib", - ":tpu_ordinal_selector_ops_op_lib", - ":tpu_replication_ops_op_lib", ":training_ops_op_lib", ":user_ops_op_lib", ":word2vec_ops", "//tensorflow/c/kernels:bitcast_op_lib", "//tensorflow/compiler/mlir/tensorflow:mlir_passthrough_op", - ] + if_mkl([ + ] + if_chromiumos( + [], + # Non-tpu platforms don't need tpu dependency. It would be best to guard + # them by if_tpu. But there is no such flag yet. + [ + ":tpu_configuration_ops_op_lib", + ":tpu_cross_replica_ops_op_lib", + ":tpu_embedding_ops_op_lib", + ":tpu_embedding_load_retrieve_ops_op_lib", + ":tpu_functional_ops_op_lib", + ":tpu_heartbeat_ops_op_lib", + ":tpu_host_compute_ops_op_lib", + ":tpu_infeed_ops_op_lib", + ":tpu_outfeed_ops_op_lib", + ":tpu_ordinal_selector_ops_op_lib", + ":tpu_replication_ops_op_lib", + ], + ) + if_mkl([ ":mkl_array_ops_op_lib", ":mkl_nn_ops_op_lib", ]) + if_tensorrt([ @@ -1006,6 +1010,7 @@ cc_library( "//tensorflow/core/kernels:collective_ops", "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:control_flow_ops", + "//tensorflow/core/kernels:count_ops", "//tensorflow/core/kernels:ctc_ops", "//tensorflow/core/kernels:data_flow", "//tensorflow/core/kernels:decode_proto_op", @@ -1265,7 +1270,7 @@ filegroup( "//tensorflow/core/platform:mobile_srcs_no_runtime", "//tensorflow/core/public:mobile_srcs_no_runtime", "//tensorflow/core/util:mobile_srcs_no_runtime", - "//tensorflow/core/util/ctc:android_srcs", + "//tensorflow/core/util/ctc:mobile_srcs", ] + glob( [ "client/**/*.cc", @@ -1295,12 +1300,12 @@ filegroup( "//tensorflow/core/common_runtime/eager:srcs", "//tensorflow/core/framework:mobile_srcs_only_runtime", "//tensorflow/core/graph:mobile_srcs_only_runtime", - "//tensorflow/core/kernels:android_srcs", + "//tensorflow/core/kernels:mobile_srcs", "//tensorflow/core/lib/io:mobile_srcs_only_runtime", "//tensorflow/core/profiler:mobile_srcs", "//tensorflow/core/public:mobile_srcs_only_runtime", "//tensorflow/core/util/sparse:mobile_srcs_only_runtime", - "//tensorflow/core/util/tensor_bundle:android_srcs", + "//tensorflow/core/util/tensor_bundle:mobile_srcs", "//tensorflow/core/util:mobile_srcs_only_runtime", # Sources for which we already have granular targets. 
@@ -1365,10 +1370,7 @@ cc_library( name = "portable_tensorflow_lib_lite", srcs = if_mobile([":mobile_srcs"]), copts = tf_copts(android_optimization_level_override = None) + tf_opts_nortti_if_lite_protos() + if_ios(["-Os"]), - defines = ["SUPPORT_SELECTIVE_REGISTRATION"] + tf_portable_full_lite_protos( - full = [], - lite = ["TENSORFLOW_LITE_PROTOS"], - ) + if_chromiumos(["IS_MOBILE_PLATFORM"]) + tf_defines_nortti_if_lite_protos(), + defines = ["SUPPORT_SELECTIVE_REGISTRATION"] + if_chromiumos(["IS_MOBILE_PLATFORM"]) + tf_defines_nortti_if_lite_protos(), linkopts = if_android(["-lz"]) + if_ios(["-lz"]), tags = [ "manual", @@ -1376,10 +1378,9 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ - ":protos_all_cc_impl", "//tensorflow/core/util:stats_calculator_portable", "//tensorflow/core:mobile_additional_lib_deps", - ] + tf_portable_deps_no_runtime(), + ] + tf_portable_proto_lib() + tf_portable_deps_no_runtime(), alwayslink = 1, ) @@ -1411,54 +1412,12 @@ cc_library( ], ) -# Native library support for iOS applications. -# -# bazel build --config=ios_x86_64 \ -# :ios_tensorflow_lib -cc_library( - name = "ios_tensorflow_lib", - srcs = if_ios([ - ":portable_op_registrations_and_gradients", - "//tensorflow/core/kernels:android_core_ops", - "//tensorflow/core/kernels:android_extended_ops", - ]), - copts = tf_copts() + tf_opts_nortti_if_lite_protos() + ["-Os"], - visibility = ["//visibility:public"], - deps = [ - ":portable_tensorflow_lib_lite", - ":protos_all_cc_impl", - "//third_party/eigen3", - "//third_party/fft2d:fft2d_headers", - "@com_google_protobuf//:protobuf", - "@fft2d", - "@gemmlowp", - ], - alwayslink = 1, -) - alias( name = "ios_tensorflow_lib_lite", actual = ":portable_tensorflow_lib_lite", visibility = ["//visibility:public"], ) -cc_library( - name = "ios_tensorflow_test_lib", - testonly = 1, - srcs = if_ios([":android_test_srcs"]), - copts = tf_copts() + ["-Os"], - tags = [ - "manual", - "notap", - ], - visibility = ["//visibility:public"], - deps = [ - ":ios_tensorflow_lib", - "//tensorflow/core/platform/default/build_config:gtest", - "//third_party/eigen3", - ], -) - # Full TensorFlow library with operator support. Use this unless reducing # binary size (by packaging a reduced operator set) is a concern. 
alias( @@ -1467,10 +1426,16 @@ alias( visibility = ["//visibility:public"], ) +alias( + name = "ios_tensorflow_lib", + actual = ":portable_tensorflow_lib", + visibility = ["//visibility:public"], +) + cc_library( name = "portable_tensorflow_lib", srcs = if_mobile([":portable_op_registrations_and_gradients"]), - copts = tf_copts() + tf_opts_nortti_if_lite_protos(), + copts = tf_copts() + tf_opts_nortti_if_lite_protos() + if_ios(["-Os"]), features = tf_features_nomodules_if_mobile(), tags = [ "manual", @@ -1553,6 +1518,12 @@ alias( visibility = ["//visibility:public"], ) +alias( + name = "ios_tensorflow_test_lib", + actual = ":portable_tensorflow_test_lib", + visibility = ["//visibility:public"], +) + cc_library( name = "portable_tensorflow_test_lib", testonly = 1, @@ -1563,7 +1534,7 @@ cc_library( "//tensorflow/core/framework:android_test_hdrs", "//tensorflow/core/util:android_test_hdrs", ], - copts = tf_copts(android_optimization_level_override = None), + copts = tf_copts(android_optimization_level_override = None) + if_ios(["-Os"]), features = tf_features_nomodules_if_mobile() + tf_opts_nortti_if_lite_protos(), tags = [ "manual", @@ -1631,20 +1602,13 @@ alias( [ alias( name = "protobuf_%s_pyclif%s" % (proto_name, target_suffix), - actual = ":protobuf/%s_pyclif%s" % (proto_name, target_suffix), + actual = "//tensorflow/core/protobuf:%s_pyclif%s" % (proto_name, target_suffix), visibility = ["//visibility:public"], ) for target_suffix in [ "", "_pb2", ] - ] + [ - tf_pyclif_proto_library( - name = "protobuf/%s_pyclif" % proto_name, - proto_lib = ":protos_all", - proto_srcfile = "protobuf/%s.proto" % proto_name, - visibility = ["//visibility:public"], - ), ] for proto_name in [ "config", @@ -1658,77 +1622,74 @@ alias( # ----------------------------------------------------------------------------- # Internal targets -tf_proto_library( +alias( name = "autotuning_proto", - srcs = ["protobuf/autotuning.proto"], - cc_api_version = 2, - make_default_target_header_only = True, + actual = "//tensorflow/core/protobuf:autotuning_proto", visibility = [ "//tensorflow:internal", ], ) -tf_proto_library( +alias( + name = "autotuning_proto_cc", + actual = "//tensorflow/core/protobuf:autotuning_proto_cc", + visibility = [ + "//tensorflow:internal", + ], +) + +alias( name = "conv_autotuning_proto", - srcs = ["protobuf/conv_autotuning.proto"], - cc_api_version = 2, - make_default_target_header_only = True, - protodeps = [ - "//tensorflow/stream_executor:dnn_proto", - ], + actual = "//tensorflow/core/protobuf:conv_autotuning_proto", visibility = [ "//tensorflow:internal", ], ) -tf_proto_library_cc( - name = "worker_proto", - srcs = ["protobuf/worker.proto"], - cc_api_version = 2, - protodeps = tf_additional_all_protos(), - visibility = ["//visibility:public"], -) - -tf_proto_library_cc( - name = "worker_service_proto", - srcs = ["protobuf/worker_service.proto"], - has_services = 1, - cc_api_version = 2, - cc_stubby_versions = ["2"], - protodeps = [":worker_proto"], +alias( + name = "conv_autotuning_proto_cc", + actual = "//tensorflow/core/protobuf:conv_autotuning_proto_cc", visibility = [ "//tensorflow:internal", ], ) -tf_proto_library_cc( - name = "master_proto", - srcs = ["protobuf/master.proto"], - cc_api_version = 2, - protodeps = tf_additional_all_protos(), - visibility = ["//tensorflow:internal"], -) - -tf_proto_library_cc( - name = "master_service_proto", - srcs = ["protobuf/master_service.proto"], - has_services = 1, - cc_api_version = 2, - cc_stubby_versions = ["2"], - protodeps = [":master_proto"], 
+alias( + name = "worker_proto_cc", + actual = "//tensorflow/core/protobuf:worker_proto_cc", visibility = [ "//tensorflow:internal", ], ) -tf_proto_library_cc( - name = "eager_service_proto", - srcs = ["protobuf/eager_service.proto"], - has_services = 1, - cc_api_version = 2, - cc_grpc_version = 1, - cc_stubby_versions = ["2"], - protodeps = tf_additional_all_protos(), +alias( + name = "worker_service_proto_cc", + actual = "//tensorflow/core/protobuf:worker_service_proto_cc", + visibility = [ + "//tensorflow:internal", + ], +) + +alias( + name = "master_proto_cc", + actual = "//tensorflow/core/protobuf:master_proto_cc", + visibility = [ + "//learning/brain/frameworks/uptc:__subpackages__", + "//tensorflow:internal", + ], +) + +alias( + name = "master_service_proto_cc", + actual = "//tensorflow/core/protobuf:master_service_proto_cc", + visibility = [ + "//tensorflow:internal", + ], +) + +alias( + name = "eager_service_proto_cc", + actual = "//tensorflow/core/protobuf:eager_service_proto_cc", visibility = [ "//tensorflow:internal", ], @@ -2140,49 +2101,14 @@ cc_library( ], ) -tf_proto_library( +alias( name = "error_codes_proto_impl", - srcs = ["protobuf/error_codes.proto"], - cc_api_version = 2, - make_default_target_header_only = True, + actual = "//tensorflow/core/protobuf:error_codes_proto_impl", ) -tf_proto_library( - name = "core_protos", - srcs = COMMON_PROTO_SRCS + [ - # Protos which are not needed on mobile builds, but should be included - # in protos_all. - # - # Note that some protos are in neither core_proto_srcs nor this - # filegroup; e.g. ones with individual proto_library targets. - "protobuf/control_flow.proto", - # TODO(ebrevdo): Re-enable once CriticalSection is in core. - # "protobuf/critical_section.proto", - "protobuf/data/experimental/snapshot.proto", - "protobuf/debug_event.proto", - "protobuf/meta_graph.proto", - "protobuf/named_tensor.proto", - "protobuf/remote_tensor_handle.proto", - "protobuf/saved_model.proto", - "protobuf/saved_object_graph.proto", - "protobuf/struct.proto", - "protobuf/tensorflow_server.proto", - "protobuf/trackable_object_graph.proto", - "protobuf/transport_options.proto", - ], - cc_api_version = 2, - make_default_target_header_only = True, - protodeps = [ - ":error_codes_proto_impl", - "//tensorflow/core/example:protos_all", - "//tensorflow/core/framework:protos_all", - "//tensorflow/core/lib/core:error_codes_proto", - "//tensorflow/core/profiler/protobuf:xplane_proto", - "//tensorflow/core/profiler:profiler_options_proto", - "//tensorflow/core/util:protos_all", - "//tensorflow/core/util:test_log_proto_impl", - ], - visibility = ["//visibility:private"], +alias( + name = "error_codes_proto_impl_cc", + actual = "//tensorflow/core/protobuf:error_codes_proto_impl_cc", ) alias( @@ -2391,10 +2317,6 @@ alias( # Library containing all of the graph construction code that is # independent of the runtime. -# -# TODO(mrry): Refactor graph_constructor.cc so that it does not depend on code -# in "common_runtime/", and then the entire "graph/" directory can be included -# in this library. 
tf_cuda_library( name = "graph", srcs = ["//tensorflow/core/graph:graph_srcs"], @@ -2478,13 +2400,9 @@ alias( visibility = ["//visibility:public"], ) -tf_proto_library_cc( - name = "replay_log_proto", - srcs = ["protobuf/replay_log.proto"], - cc_api_version = 2, - protodeps = [ - ":master_proto", - ] + tf_additional_all_protos(), +alias( + name = "replay_log_proto_cc", + actual = "//tensorflow/core/protobuf:replay_log_proto_cc", visibility = [ "//tensorflow:internal", ], @@ -2740,42 +2658,6 @@ tf_cc_tests( ], ) -tf_cc_tests( - name = "higher_level_tests_needing_kernels", - size = "small", - srcs = [ - "//tensorflow/core/graph:higher_level_tests_needing_kernels", - ], - linkopts = select({ - "//tensorflow:macos": ["-headerpad_max_install_names"], - "//conditions:default": [], - }), - linkstatic = tf_kernel_tests_linkstatic(), - deps = [ - ":all_kernels", - ":core", - ":core_cpu", - ":core_cpu_internal", - ":direct_session_internal", - ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":ops", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:cc_ops_internal", - "//tensorflow/cc:scope", - "//tensorflow/cc:sendrecv_ops", - "//tensorflow/core/kernels:ops_util", - "//tensorflow/core/util:protos_test_cc", - "//third_party/eigen3", - ], -) - tf_cc_test( name = "cudnn_rnn_ops_test_cc", size = "small", @@ -3151,6 +3033,11 @@ alias( actual = "//tensorflow/core/platform:cuda_libdevice_path", ) +# Normalize CORE_PROTO_SRCS to generate valid output file names. +PORTABLE_PROTO_HEADERS_OUT = tf_android_core_proto_headers(CORE_PROTO_SRCS) + [ + "//google/protobuf/any.proto.h", +] + transitive_hdrs( name = "headers", visibility = ["//tensorflow:__subpackages__"], @@ -3163,8 +3050,3 @@ transitive_hdrs( "//tensorflow/core/platform:platform_strings", ], ) - -# Normalize CORE_PROTO_SRCS to generate valid output file names. -PORTABLE_PROTO_HEADERS_OUT = tf_android_core_proto_headers(CORE_PROTO_SRCS) + [ - "//google/protobuf/any.proto.h", -] diff --git a/tensorflow/core/api_def/base_api/api_def_AdjustHue.pbtxt b/tensorflow/core/api_def/base_api/api_def_AdjustHue.pbtxt index bfaf6768601..c34b5c6fbcb 100644 --- a/tensorflow/core/api_def/base_api/api_def_AdjustHue.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_AdjustHue.pbtxt @@ -21,7 +21,7 @@ END summary: "Adjust the hue of one or more images." description: < l1 else 0.0 accum = accum_new diff --git a/tensorflow/core/api_def/base_api/api_def_ApplyFtrlV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_ApplyFtrlV2.pbtxt index 3218ab7776c..1eb33005e91 100644 --- a/tensorflow/core/api_def/base_api/api_def_ApplyFtrlV2.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_ApplyFtrlV2.pbtxt @@ -65,8 +65,8 @@ END summary: "Update \'*var\' according to the Ftrl-proximal scheme." description: < l1 else 0.0 diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestFeatureSplitV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestFeatureSplitV2.pbtxt index 2bbaba26257..84382d8a99c 100644 --- a/tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestFeatureSplitV2.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestFeatureSplitV2.pbtxt @@ -47,7 +47,7 @@ END in_arg { name: "min_node_weight" description: <